Commit 715894a
Adopt PreTrainedModelWrapper for Hugging Face models (#215)
* Adopt `PreTrainedModelWrapper` for Hugging Face models
* Update documentation
* Fix up broken merge
* Run pre-commit
* Revert dtype change to `ILQLHead`
* Fix isort
* Format again...
* Revert newline deletion
* Revert unrelated changes and update docs
* Update `README.md` saving example
* Revert unrelated changes
* Fix `dtype` access and hydra `return_dict`
* Force ref models into eval mode
* Add unit tests for `AutoModel...`s
* Commit work on fixing `T5Branch`
* refactor(models): move models out of trainer dir
* refactor(sft): remove `save_pretrained` override
* Run pre-commit
* Ignore line length for links
* Revert naming to `base_model`
* Rename hydra models for clarity
* Add `from_config` support
* cleanup docstrings
* Revert T5 branch changes
* Remove variadic params
1 parent 93c90cb commit 715894a

20 files changed: +1,191 −749 lines
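For orientation, a rough sketch of the interface this commit adopts (not part of the diff): the import path assumes the post-refactor `trlx/models` layout, the class name reflects the commit's "Rename hydra models for clarity", and `gpt2` is just an example checkpoint.

```python
import transformers

# Assumed import path, based on the commit's move of models out of the trainer dir
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

# Load weights through the parent AutoModel class (transformers.AutoModelForCausalLM)
model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("gpt2")

# Or build an uninitialized model from a config alone, using the `from_config`
# support added in this commit; no weights are loaded here
config = transformers.AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLMWithHydraValueHead.from_config(config)
```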

README.md (+2 −4)
````diff
@@ -63,8 +63,6 @@ trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'),
 trainer.save_pretrained('/path/to/output/folder/')
 ```
 
-🩹 Warning: Only the `AcceleratePPOTrainer` can write HuggingFace transformers to disk with `save_pretrained` at the moment, as ILQL trainers require inference behavior currently unsupported by available `transformers` architectures.
-
 #### Use 🤗 Accelerate to launch distributed training
 
 ```bash
@@ -74,13 +72,13 @@ accelerate launch examples/simulacra.py
 
 #### Use NeMo-Megatron to launch distributed training
 
-Follow the setup instructions in the [NeMo README](./trlx/trainer/nemo).
+Follow the setup instructions in the [NeMo README](./trlx/models/).
 
 ```bash
 python examples/nemo_ilql_sentiments.py
 ```
 
-For more usage see the [NeMo README](./trlx/trainer/nemo)
+For more usage see the [NeMo README](./trlx/models)
 
 #### Use Ray Tune to launch hyperparameter sweep
````
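With the warning removed, a checkpoint written by `trainer.save_pretrained` should be an ordinary Hugging Face checkpoint. A minimal sketch of reloading one with the standard `transformers` API; the path is the README's placeholder, and the tokenizer line assumes a tokenizer was saved alongside the model:

```python
# Minimal sketch: reload a checkpoint written by `trainer.save_pretrained`
# using the standard `transformers` API. The path is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('/path/to/output/folder/')
tokenizer = AutoTokenizer.from_pretrained('/path/to/output/folder/')
```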

docs/source/trainer.rst (−15)
```diff
@@ -19,22 +19,7 @@ Note that new trainers must be registered with ``trlx.trainer.register_trainer``
 .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer
    :members:
 
-.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead
-   :members:
-
 **ILQL**
 
 .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer
    :members:
-
-.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads
-   :members:
```

tests/test_models.py (+340)

Large diffs are not rendered by default.
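The 340 added lines are not rendered here; per the commit message they add unit tests for the `AutoModel...` wrappers. Purely as a hypothetical sketch of the shape such a test might take (the import path, class name, and assertions are assumptions, not taken from the diff):

```python
import transformers

# Assumed import path and class name; the actual tests are not shown above
from trlx.models.modeling_ppo import AutoModelForCausalLMWithValueHead


def test_from_config_builds_wrapper():
    # `from_config` support is one of the features this commit adds
    config = transformers.AutoConfig.from_pretrained("gpt2")
    model = AutoModelForCausalLMWithValueHead.from_config(config)
    # The commit reverts the wrapped module's name to `base_model`
    assert isinstance(model.base_model, transformers.PreTrainedModel)
```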

tests/test_ppo.py (−84)

This file was deleted.

tests/test_utils.py (+25 −3)
```diff
@@ -1,3 +1,5 @@
+import unittest
+
 import accelerate
 import pytest
 import torch
@@ -68,9 +70,9 @@ def test_hf_attr_getters(model_name: str):
     arch = transformers.AutoModelForCausalLM.from_config(config)
 
     arch_getters = [
-        modeling_utils.hf_get_causal_base_model,
-        modeling_utils.hf_get_causal_final_norm,
-        modeling_utils.hf_get_causal_hidden_layers,
+        modeling_utils.hf_get_decoder,
+        modeling_utils.hf_get_decoder_final_norm,
+        modeling_utils.hf_get_decoder_blocks,
         modeling_utils.hf_get_lm_head,
     ]
     for get in arch_getters:
@@ -125,3 +127,23 @@ def test_parse_delta_kwargs(model_name):
     )
     for kwarg_mod in delta_kwargs["modified_modules"]:
         assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']"
+
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = modeling_utils.RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
```
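The new `TestStatistics` case exercises `RunningMoments`, which folds batches into a streaming mean and unbiased standard deviation. Its implementation is not part of this diff; a minimal sketch of the parallel-variance combine step (Chan et al.) that the assertions imply could look like:

```python
import torch


class RunningMomentsSketch:
    """Illustrative sketch only; trlx's `RunningMoments` is not shown in this diff."""

    def __init__(self):
        self.mean = torch.tensor(0.0, dtype=torch.float64)
        self.var = torch.tensor(1.0, dtype=torch.float64)
        self.count = 1e-24  # avoid division by zero before the first update

    def update(self, xs: torch.Tensor):
        """Fold a batch into the running moments; returns the batch's own mean
        and unbiased std, matching how the test indexes `update(...)[1]`."""
        xs = xs.double()
        batch_mean, batch_var = xs.mean(), xs.var(unbiased=False)
        batch_count = xs.numel()

        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Combine the means, then combine the two sums of squared deviations
        self.mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.var = m2 / total
        self.count = total

        # Unbiased std of the batch itself
        return batch_mean, (batch_var * batch_count / (batch_count - 1)).sqrt()

    @property
    def std(self):
        # Convert the pooled (biased) variance to an unbiased estimate at read time
        return (self.var * self.count / (self.count - 1)).sqrt()
```

Accumulating in float64 and deferring the biased-to-unbiased conversion to read time is what lets the final `self.m.std` match `a.std(unbiased=True)` within `atol=1e-6`.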
Two files renamed without changes.

trlx/models/modeling_base.py (+223, new file)
```python
# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from
# HuggingFace's `trl` library. The original source code can be found here:
# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py

import inspect
import json
import os
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download


class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
    """A wrapper around `transformers.PreTrainedModel`

    Reference: @younesbelkada's `PreTrainedModelWrapper`
    https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2

    Attributes:
        _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel`
            type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`.
        _supported_modules (List[str]): A list of attribute names for modules of
            the underlying architecture model. This is used, for example, to save
            and load any additional modules by manipulating the state dict.
        _supported_args (List[str]): A list of arguments specific to the underlying
            architecture to separate from arguments that are supported by the
            parent `AutoModel` class. Any arguments that are not supported by the
            underlying model will be passed to the parent `AutoModel` class.
    """

    _auto_model_parent_class: transformers.AutoModel = None
    _supported_modules: List[str] = None
    # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the
    # specific underlying type similar to how config instances can be used to instantiate
    # `transformers.PreTrainedModel`s.
    _supported_args: List[str] = None

    def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs):
        super().__init__()
        self.base_model = base_model
        # cache `forward` args for general use (avoids incompatible args across architectures)
        self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args

    @classmethod
    def _split_kwargs(cls, kwargs: Dict[str, Any]):
        """Separates the kwargs from the supported arguments within `supported_args`
        and those that are not
        """
        supported_kwargs = {}
        unsupported_kwargs = {}
        for key, value in kwargs.items():
            if key in cls._supported_args:
                supported_kwargs[key] = value
            else:
                unsupported_kwargs[key] = value
        return supported_kwargs, unsupported_kwargs

    @classmethod
    def from_config(cls, config: transformers.PretrainedConfig, **kwargs):
        """Instantiate the pretrained pytorch model from a configuration.

        Args:
            config (transformers.PretrainedConfig): The configuration to use to
                instantiate the base model.

        NOTE: Loading a model from its configuration file does **not** load the
        model weights. It only affects the model's configuration. Use
        `~transformers.AutoModel.from_pretrained` to load the model weights.
        """
        if kwargs is not None:
            wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs)
        else:
            from_config_kwargs = {}
            wrapped_model_kwargs = {}
        base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs)
        model = cls(base_model, **wrapped_model_kwargs)
        return model

    @classmethod
    def from_pretrained(  # noqa: max-complexity
        cls,
        pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel],
        *model_args,
        **kwargs,
    ):
        """Instantiate a pretrained pytorch model from a pretrained model configuration.
        This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`.
        Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained`
        for more information.

        Args:
            pretrained_model_name_or_path (str or `transformers.PreTrainedModel`):
                The identifier of the pretrained model to load or the pretrained model itself.
            *model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the `_auto_model_parent_class`.
            **kwargs (dict, *optional*):
                Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class`
                call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific
                instance of the wrapped model.

        NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.
        """
        if kwargs is not None:
            wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs)
        else:
            from_pretrained_kwargs = {}
            wrapped_model_kwargs = {}

        if isinstance(pretrained_model_name_or_path, str):
            # Load the base model using the `transformers` AutoClass (e.g. `AutoModelForCausalLM`)
            base_model = cls._auto_model_parent_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs
            )
        elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):
            base_model = pretrained_model_name_or_path
        else:
            raise ValueError(
                f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}. "
                "Expected `str` or `transformers.PreTrainedModel`."
            )

        model = cls(base_model, **wrapped_model_kwargs)

        if isinstance(pretrained_model_name_or_path, str):
            filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
            is_sharded = False

            if not os.path.exists(filename):
                try:
                    filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
                # Sharded
                except Exception:
                    if os.path.exists(sharded_index_filename):
                        index_file_name = sharded_index_filename
                    else:
                        index_file_name = hf_hub_download(
                            pretrained_model_name_or_path,
                            "pytorch_model.bin.index.json",
                        )
                    with open(index_file_name, "r") as f:
                        index = json.load(f)
                    # Collect files containing weights from supported modules
                    files_to_download = set()
                    for k, v in index["weight_map"].items():
                        if any([module in k for module in cls._supported_modules]):
                            files_to_download.add(v)
                    is_sharded = True

            if is_sharded:
                # Merge each shard into a state dict
                # TODO: Optimize this to avoid wasting RAM
                state_dict = {}
                for shard_file in files_to_download:
                    filename = os.path.join(pretrained_model_name_or_path, shard_file)
                    # Download if shard file doesn't exist locally
                    if not os.path.exists(filename):
                        filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
                    state_dict.update(torch.load(filename, map_location="cpu"))
            else:
                state_dict = torch.load(filename, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path.state_dict()

        model.post_init(state_dict=state_dict)
        return model

    def save_pretrained(self, *args, **kwargs):
        """Save the pretrained model to a directory. This method is a wrapper
        around `transformers.PreTrainedModel.save_pretrained`. Please refer to
        the documentation of `transformers.PreTrainedModel.save_pretrained` for
        more information.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed along to the underlying model's
                `save_pretrained` method.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed along to the underlying model's
                `save_pretrained` method.
        """
        state_dict = kwargs.pop("state_dict", None)
        if state_dict is None:
            state_dict = self.state_dict()
        kwargs["state_dict"] = state_dict

        return self.base_model.save_pretrained(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Return the state_dict of the pretrained model."""
        raise NotImplementedError

    def post_init(self, *args, **kwargs):
        """Post initialization method. This method is called after the model is
        instantiated and loaded from a checkpoint. It can be used to perform
        additional operations such as loading the state_dict.
        """
        raise NotImplementedError

    def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
        """Filter out arguments not supported by the specific instance of
        `base_model.transformer.forward`
        """
        # FIXME: This is a hack to get around the fact that the `transformers`
        # architectures we use don't have a consistent API for `forward` parameters.
        return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}
```
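To make the wrapper's contract concrete, here is a minimal hypothetical subclass. `ExampleValueHeadModel`, its `v_head` module, and the `value_head_dtype` argument are invented for illustration; only the required class attributes and the `state_dict`/`post_init` overrides come from the class above.

```python
import torch
import torch.nn as nn
import transformers

from trlx.models.modeling_base import PreTrainedModelWrapper


class ExampleValueHeadModel(PreTrainedModelWrapper):
    # Parent AutoModel class used by `from_pretrained`/`from_config`
    _auto_model_parent_class = transformers.AutoModelForCausalLM
    # Extra module names, matched against checkpoint keys when collecting shards
    _supported_modules = ["v_head"]
    # Kwargs consumed by the wrapper rather than forwarded to the AutoModel call
    _supported_args = ["value_head_dtype"]

    def __init__(self, base_model, value_head_dtype: torch.dtype = torch.float32):
        super().__init__(base_model)
        self.v_head = nn.Linear(base_model.config.hidden_size, 1, dtype=value_head_dtype)

    def state_dict(self, *args, **kwargs):
        # Merge the base model's weights with the extra head, prefixing the
        # head's keys so `post_init` can recover them on load
        state_dict = self.base_model.state_dict(*args, **kwargs)
        for k, v in self.v_head.state_dict().items():
            state_dict[f"v_head.{k}"] = v
        return state_dict

    def post_init(self, state_dict):
        # Called by `from_pretrained` with the checkpoint's state dict;
        # pick out and load any `v_head.*` weights
        head_state = {k[len("v_head."):]: v for k, v in state_dict.items() if k.startswith("v_head.")}
        if head_state:
            self.v_head.load_state_dict(head_state)
```

With those three class attributes and two overrides in place, the inherited `from_pretrained`, `from_config`, and `save_pretrained` work unchanged.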
