
Commit 16a482c

Add pre-commit with black (#36)
* add pre-commit
* format
1 parent d90dc88 · commit 16a482c

29 files changed: +236 −124

.gitignore

+1 −1

@@ -144,4 +144,4 @@ nbs/wandb/

 wandb/

-OUT/
+OUT/

.pre-commit-config.yaml

+17

@@ -0,0 +1,17 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+# This should be the _latest_ version of python supported by us
+default_language_version:
+  python: python3.9
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+  - repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+      - id: black
+        files: ^(trlx|examples|unittests|setup.py)/
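
For anyone applying this change locally: with `pre-commit` pinned in `requirements.txt` below, running `pre-commit install` once in the repository root registers these hooks so they fire on every `git commit`, and `pre-commit run --all-files` applies them to the whole tree; the whitespace and formatting churn in the rest of this diff is consistent with such a run.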

README.md

+1 −1

@@ -22,7 +22,7 @@ The training pipeline is broken into four pieces:
 - Orchestrator: Handles exploration/rollout collection of online methods. Pushes collected rollouts to the rollout pipeline.
 - Model: Wraps the supplied base model (ex: `gpt2`) and implements the desired training method loss (ex: PPO).

-Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are online and have no reward labeled data this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.
+Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are online and have no reward labeled data this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.

 ## Example: How to add a task
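
The reformatted README paragraph is the relevant how-to here: an online task needs only a new prompt pipeline plus a reward function handed to `PPOOrchestrator`. As a rough sketch only (the exact orchestrator signature is not part of this diff), a reward function follows the convention visible in `examples/ilql_randomwalks.py` further down: it maps a batch of generated samples to one score per sample.

```python
# Minimal, hypothetical reward function: scores each generated string.
# The target substring "happy" is an arbitrary placeholder, not part of this commit.
def reward_fn(samples):
    return [float("happy" in sample.lower()) for sample in samples]
```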

configs/ppo_config.yml

+1 −1

@@ -49,4 +49,4 @@ method:
   min_length : 48 # LM min sample gen length
   top_k : 0.0 # top k
   top_p : 1.0 # top p
-  do_sample : True # sample
+  do_sample : True # sample

configs/ppo_gptj.yml

+1 −1

@@ -50,4 +50,4 @@ method:
   top_k : 0.0 # top k
   top_p : 0.7 # top p
   do_sample : True # sample
-  temperature: 0.5
+  temperature: 0.5

configs/test_config.yml

+1 −1

@@ -49,4 +49,4 @@ method:
   min_length : 48 # LM min sample gen length
   top_k : 0.0 # top k
   top_p : 1.0 # top p
-  do_sample : True # sample
+  do_sample : True # sample

docs/source/configs.rst

+4 −4

@@ -6,10 +6,10 @@ Configs
 Training a model in TRL will require you to set several configs:
 ModelConfig, which contains general info on the model being trained. TrainConfig, which contains things like
 training hyperparameters. And finally, MethodConfig, which contains hyperparameters or settings for
-the specific method being used (i.e. ILQL or PPO)
+the specific method being used (i.e. ILQL or PPO)


-**General**
+**General**

 .. autoclass:: trlx.data.configs.TRLConfig
    :members:
@@ -21,9 +21,9 @@ the specific method being used (i.e. ILQL or PPO)
    :members:

 .. autoclass:: trlx.data.method_configs.MethodConfig
-   :members:
+   :members:

-**PPO**
+**PPO**

 .. autoclass:: trlx.data.method_configs.PPOConfig
    :undoc-members:

docs/source/data.rst

+7 −7

@@ -3,15 +3,15 @@
 Data Elements
 ************************

-All of the major Carper projects: trlX, CHEESE, and magiCARP use
+All of the major Carper projects: trlX, CHEESE, and magiCARP use
 dataclasses corresponding to batches of data to communicate data between models and different
 components. trlX is no different, though it has many different dataclasses for
 different components like training or inference. Currently, we support PPO and ILQL, which
 each demand different kinds of data during training.
-

-**Basic Data Elements for Accelerate**
-
+
+**Basic Data Elements for Accelerate**
+

 .. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement
    :members:
@@ -25,9 +25,9 @@ each demand different kinds of data during training.
 .. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement
    :members:

-
-**Data Elements for PPO**
-
+
+**Data Elements for PPO**
+
 .. autoclass:: trlx.data.ppo_types.PPORLElement
    :members:

docs/source/index.rst

+1 −1

@@ -6,7 +6,7 @@
 Welcome to trlX's documentation!
 ================================
 trlX is a library made for training large language models using reinforcement learning. It
-currently supports training using PPO or ILQL for models up to 20B using Accelerate.
+currently supports training using PPO or ILQL for models up to 20B using Accelerate.

 .. toctree::
    :maxdepth: 2

docs/source/models.rst

+4 −4

@@ -3,18 +3,18 @@
 RL Models
 *******************

-RL Models are what you're training with trlX. Currently, we support PPO and ILQL.
+RL Models are what you're training with trlX. Currently, we support PPO and ILQL.
 Note that new models must be registered with ``trlx.model.register_model``.
-
-**General**
+
+**General**

 .. autoclass:: trlx.model.BaseRLModel
    :members:

 .. autoclass:: trlx.model.accelerate_base_model.AccelerateRLModel
    :members:

-**PPO**
+**PPO**

 .. autoclass:: trlx.model.accelerate_ppo_model.AcceleratePPOModel
    :members:

docs/source/orchestrator.rst

+3 −3

@@ -7,17 +7,17 @@ Orchestrators manage reading data from a pipeline and creating RL data elements
 to push to a models rollout storage. Use the ``trlx.orchestrator.register_orchestrator`` decorator when creating
 new orchestrators.

-**General**
+**General**

 .. autoclass:: trlx.orchestrator.Orchestrator
    :members:

-**PPO**
+**PPO**

 .. autoclass:: trlx.orchestrator.ppo_orchestrator.PPOOrchestrator
    :members:

-**ILQL**
+**ILQL**

 .. autoclass:: trlx.orchestrator.offline_orchestrator.OfflineOrchestrator
    :members:

examples/ilql_randomwalks.py

+1 −3

@@ -99,9 +99,7 @@ def reward_fn(samples):
     n_layer=4, n_embd=144, vocab_size=logit_mask.shape[0]
 )

-model = ILQLModel(
-    config=config, logit_mask=logit_mask
-)
+model = ILQLModel(config=config, logit_mask=logit_mask)

 orch = OfflineOrchestrator(
     model=model,

requirements.txt

+1

@@ -3,6 +3,7 @@ datasets==2.4.0
 deepspeed==0.7.3
 einops==0.4.1
 numpy==1.23.2
+pre-commit==2.20.0
 tqdm==4.64.0
 transformers==4.21.2
 wandb==0.13.2

setup.cfg

+2 −2

@@ -4,7 +4,7 @@ author = Alex Havrilla
 version = 1.0.0

 [options]
-install_requires =
+install_requires =
     accelerate
     datasets
     deepspeed
@@ -13,4 +13,4 @@ install_requires =
     tqdm
     transformers
     wandb
-    torchtyping
+    torchtyping

setup.py

+1 −1

@@ -1,3 +1,3 @@
 from setuptools import setup

-setup()
+setup()

trlx/data/accelerate_base_datatypes.py

+15 −8

@@ -16,8 +16,10 @@ class PromptElement:
     :param tokens: The prompt tokens. Should be a long tensor
     :type tokens: torch.Tensor
     """
-    text : str
-    tokens : TensorType["num_tokens"]
+
+    text: str
+    tokens: TensorType["num_tokens"]
+

 @dataclass
 class PromptBatch:
@@ -30,8 +32,10 @@ class PromptBatch:
     :param tokens: A long tensor batch of prompt tokens.
     :type tokens: torch.Tensor
     """
-    text : Iterable[str]
-    tokens : TensorType["batch_size", "num_tokens"]
+
+    text: Iterable[str]
+    tokens: TensorType["batch_size", "num_tokens"]
+

 @dataclass
 class AccelerateRLElement:
@@ -44,8 +48,10 @@ class AccelerateRLElement:
     :param rewards: The rewards for each token. Should be a float tensor of same size as tokens.
     :type rewards: torch.Tensor
     """
-    output_tokens : TensorType["output_size"]
-    rewards : TensorType["output_size"]
+
+    output_tokens: TensorType["output_size"]
+    rewards: TensorType["output_size"]
+

 @dataclass
 class AccelerateRLBatchElement:
@@ -58,5 +64,6 @@ class AccelerateRLBatchElement:
     :param rewards: Batches of float tensors of rewards for each output token.
     :type rewards: torch.Tensor
     """
-    output_tokens : TensorType["batch_size", "output_size"]
-    rewards : TensorType["batch_size", "output_size"]
+
+    output_tokens: TensorType["batch_size", "output_size"]
+    rewards: TensorType["batch_size", "output_size"]

trlx/data/configs.py

+16 −13

@@ -24,11 +24,12 @@ class ModelConfig:
     :param device: Device to use when doing single GPU training. Not needed in most cases.
     :type device: str
     """
-    model_path : str
-    tokenizer_path : str
-    model_type : str # One of the architectures present in framework.model
-    device : str = ''
-    num_layers_unfrozen : int = -1
+
+    model_path: str
+    tokenizer_path: str
+    model_type: str  # One of the architectures present in framework.model
+    device: str = ""
+    num_layers_unfrozen: int = -1

     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
@@ -91,11 +92,12 @@ class TrainConfig:
     :param project_name: Project name for wandb
     :type project_name: str
     """
-    n_ctx : int
-    epochs : int
-    total_steps : int
-    batch_size : int
-    grad_clip : float # Clip grad norms to this value
+
+    n_ctx: int
+    epochs: int
+    total_steps: int
+    batch_size: int
+    grad_clip: float  # Clip grad norms to this value

     lr_ramp_steps: int
     lr_decay_steps: int
@@ -128,9 +130,10 @@ class TRLConfig:
     """
     Top level config for trlX. Loads configs and can be converted to dictionary.
     """
-    model : ModelConfig
-    train : TrainConfig
-    method : MethodConfig
+
+    model: ModelConfig
+    train: TrainConfig
+    method: MethodConfig

     @classmethod
     def load_yaml(cls, yml_fp: str):
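
The context above shows `TRLConfig.load_yaml` and the three sub-configs it bundles; assuming it accepts a path to one of the YAML files under `configs/` (a guess from the `yml_fp` parameter and the config files earlier in this diff, not confirmed by the commit), loading and inspecting a config would look roughly like:

```python
from trlx.data.configs import TRLConfig

# Path and attribute access assumed from the configs/ directory and the
# dataclass fields shown in this diff; purely illustrative.
config = TRLConfig.load_yaml("configs/ppo_config.yml")
print(config.method.name, config.train.batch_size)
```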

trlx/data/ilql_types.py

+2

@@ -20,6 +20,7 @@ class ILQLElement:
     :param rewards: Rewards for each token. Should be a float tensor of same size as tokens.
     :type rewards: torch.Tensor
     """
+
     input_ids: TensorType["query_size"]
     attention_mask: TensorType["query_size"]
     rewards: TensorType["reward_size"]
@@ -39,6 +40,7 @@ class ILQLBatch:
     :param rewards: Batch of rewards for each token in each token batch.
     :type rewards: torch.Tensor
     """
+
     input_ids: TensorType["batch_size", "query_size"]
     attention_mask: TensorType["batch_size", "query_size"]
     rewards: TensorType["batch_size", "reward_size"]

trlx/data/method_configs.py

+2 −1

@@ -50,7 +50,8 @@ class MethodConfig:
     :param name: Name of the method
     :type name: str
     """
-    name : str
+
+    name: str

     @classmethod
     def from_dict(cls, config: Dict[str, Any]):

trlx/data/ppo_types.py

+13 −10

@@ -25,11 +25,13 @@ class PPORLElement:
     :param rewards: The rewards for each token outputted in response. Should be a float tensor of same size as tokens.
     :type rewards: torch.Tensor
     """
-    query_tensor : TensorType["query_size"]
-    response_tensor : TensorType["response_size"]
-    logprobs : TensorType["response_size", "vocab_size"]
-    values : TensorType["response_size"]
-    rewards : TensorType["response_size"]
+
+    query_tensor: TensorType["query_size"]
+    response_tensor: TensorType["response_size"]
+    logprobs: TensorType["response_size", "vocab_size"]
+    values: TensorType["response_size"]
+    rewards: TensorType["response_size"]
+

 @dataclass
 class PPORLBatch:
@@ -51,8 +53,9 @@ class PPORLBatch:
     :param rewards: A batch of rewards
     :type rewards: torch.Tensor
     """
-    query_tensors : TensorType["batch_size", "query_size"]
-    response_tensors : TensorType["batch_size", "response_size"]
-    logprobs : TensorType["batch_size", "response_size", "vocab_size"]
-    values : TensorType["batch_size", "response_size"]
-    rewards : TensorType["batch_size", "response_size"]
+
+    query_tensors: TensorType["batch_size", "query_size"]
+    response_tensors: TensorType["batch_size", "response_size"]
+    logprobs: TensorType["batch_size", "response_size", "vocab_size"]
+    values: TensorType["batch_size", "response_size"]
+    rewards: TensorType["batch_size", "response_size"]
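
To make the annotated shapes above concrete, here is a hypothetical construction of a `PPORLElement` (field names come from the diff above; the tensor sizes and values are made up):

```python
import torch

from trlx.data.ppo_types import PPORLElement

# Illustrative sizes only: a query of 3 tokens, a response of 2 tokens, vocab of 4.
element = PPORLElement(
    query_tensor=torch.tensor([11, 42, 7]),
    response_tensor=torch.tensor([3, 9]),
    logprobs=torch.randn(2, 4),        # per response token, per vocab entry
    values=torch.randn(2),             # value estimate per response token
    rewards=torch.tensor([0.0, 1.0]),  # reward per response token
)
```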
