
Commit 4390d3f

[style] Increase black's line length (#250)
* style(*): increase black --line-length to `120`
* chore!(pre-commit): update isort to `5.12.0` to resolve a discrepancy; for some odd reason, CI's style check differs from the local `run --all-files`
* style(*): satisfy isort
1 parent a92a971 commit 4390d3f

33 files changed: 209 additions (+), 658 deletions (-)
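Note that the diff below only bumps hook revisions in `.pre-commit-config.yaml`; the 120-character limit itself lives in the project's formatter configuration, which is not shown in this excerpt. As a minimal sketch, assuming black and isort are configured through `pyproject.toml` and that isort is kept compatible with black, the relevant entries would look something like:

[tool.black]
# Raise black's default limit of 88 characters to 120
# (hypothetical location; this file is not part of the diff shown here).
line-length = 120

[tool.isort]
# Keep isort's wrapping consistent with black at the same width.
profile = "black"
line_length = 120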

.pre-commit-config.yaml

+3-3
@@ -2,7 +2,7 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v4.4.0
     hooks:
       - id: check-case-conflict
       - id: check-json
@@ -19,12 +19,12 @@ repos:
      - id: requirements-txt-fixer
      - id: trailing-whitespace
   - repo: https://github.com/psf/black
-    rev: 22.10.0
+    rev: 23.1.0
     hooks:
       - id: black
         files: ^(trlx|examples|tests|setup.py)/
   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.12.0
     hooks:
       - id: isort
         name: isort (python)

examples/architext.py

+1-3
@@ -38,9 +38,7 @@ def reward_fn(samples, **kwargs):
 def main(hparams={}):
     config = TRLConfig.update(default_config, hparams)

-    trlx.train(
-        "architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config
-    )
+    trlx.train("architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config)


 if __name__ == "__main__":

examples/experiments/grounded_program_synthesis/lang.py

+2-3
@@ -21,6 +21,7 @@ def init_random_input(len_range: int = 5, value_gen=5) -> list:

 const_integer = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]

+
 # Functions in the DSL
 # Each function defines a transformation in the given DSL Grammar.
 def take(input_list: list, n: int) -> list:
@@ -372,9 +373,7 @@ def basic_stats(dataset, tokenizer):
     """
     length_list = []
     for examples in tqdm(dataset):
-        datapoint = tokenizer(
-            examples["input"] + " " + examples["output"] + "<|endoftext|>"
-        )
+        datapoint = tokenizer(examples["input"] + " " + examples["output"] + "<|endoftext|>")
         length_list.append(len(datapoint["input_ids"]))
     return {
         "max": max(length_list),

examples/experiments/grounded_program_synthesis/train_trlx.py

+3-15
@@ -75,20 +75,8 @@ def main(hparams={}):

 if __name__ == "__main__":
     # TEST REWARD FUNTION
-    assert (
-        reward_fn(
-            ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"]
-        )
-    ) == [1]
-    assert (
-        reward_fn(
-            ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"]
-        )
-    ) == [-1]
-    assert (
-        reward_fn(
-            ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"]
-        )
-    ) == [-0.5]
+    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"])) == [1]
+    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"])) == [-1]
+    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"])) == [-0.5]

     main()

examples/randomwalks/randomwalks.py

+4-19
@@ -5,9 +5,7 @@
 import torch


-def generate_rand_int_excluding(
-    rng: np.random.RandomState, max: int, exclude: int
-) -> int:
+def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int:
     """Random integer generator, excluding a specific number

     Args:
@@ -35,12 +33,7 @@ def generate_random_walks(  # noqa: max-complexity
     p_edge: float = 0.1,
     seed: int = 1002,
     gpt2_tokenizer: bool = False,
-) -> Tuple[
-    Callable[[List[str]], Dict[str, List[float]]],
-    List[str],
-    List[str],
-    torch.Tensor,
-]:
+) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor,]:
     """Generate random walks

     Args:
@@ -106,7 +99,6 @@ def generate_random_walks(  # noqa: max-complexity

     # Create n_walks samples
     for _ in range(n_walks):
-
         # Create a random starting node (that isn't already at the goal state)
         node: int = generate_rand_int_excluding(rng, n_nodes, goal)

@@ -116,7 +108,6 @@ def generate_random_walks(  # noqa: max-complexity
         # Do a series of steps, until we hit the maximum number of steps or the
         # goal state (whichever comes first)
         for _step in range(max_length - 1):
-
             # From the starting node, get all the nodes we can move to. Pick one
             # of these at random, and add it to the list of visited nodes
             node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
@@ -143,9 +134,7 @@ def generate_random_walks(  # noqa: max-complexity
     for start in set(range(n_nodes)) - {goal}:
         try:
             # Find the shortest path (up to the max_length)
-            shortest_path = nx.shortest_path(directional_graph, start, goal)[
-                :max_length
-            ]
+            shortest_path = nx.shortest_path(directional_graph, start, goal)[:max_length]
             shortest_lengths.append(len(shortest_path))
         except Exception:
             # If there is no path, use the maximum length instead
@@ -186,11 +175,7 @@ def metric_fn(
         for node in range(len(sample)):
             # If an invalid path is taken, set the length to the invalid
             # path score
-            if (
-                sample[node] >= n_nodes
-                or node > 0
-                and not adjacency_matrix[sample[node - 1], sample[node]]
-            ):
+            if sample[node] >= n_nodes or node > 0 and not adjacency_matrix[sample[node - 1], sample[node]]:
                 length = invalid_path_length
                 break

examples/summarize_daily_cnn/t5_summarize_daily_cnn.py

+4-11
@@ -12,8 +12,7 @@
     import evaluate
 except ImportError:
     raise ImportError(
-        "To run this example, please install the `evaluate` and `nltk` packages"
-        "by running `pip install evaluate`"
+        "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`"
     )

 config_path = pathlib.Path(__file__).parent / "configs/ppo_config_cnn_daily.yml"
@@ -26,9 +25,7 @@
 def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
     original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
     scores = [
-        meteor.compute(predictions=[output.strip()], references=[original])[
-            "meteor"
-        ]
+        meteor.compute(predictions=[output.strip()], references=[original])["meteor"]
         for (original, output) in zip(original_summaries, outputs)
     ]
     return scores
@@ -41,9 +38,7 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
 prompts = ["Summarize: " + prompt for prompt in prompts]

 # take 1,000 samples from the validation set as prompts for evaluation
-val_prompts = [
-    "Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]
-]
+val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]]
 val_summaries = dataset["validation"]["highlights"][0:1000]

 # make dictionary of prompts and labels to use for reward function
@@ -63,9 +58,7 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):

 for i in tqdm(range(len(val_prompts))):
     key = tokenizer.decode(
-        tokenizer(val_prompts[i], truncation=True, max_length=max_length)[
-            "input_ids"
-        ],
+        tokenizer(val_prompts[i], truncation=True, max_length=max_length)["input_ids"],
         skip_special_tokens=True,
     )  # get prompt like trlx's prompt
     prompt_label[key.strip()] = val_summaries[i]

examples/summarize_rlhf/reward_model/gptj_reward_test.py

+4-12
@@ -16,9 +16,7 @@ def set_seed(seed_val=42):
     torch.cuda.manual_seed_all(seed_val)


-def create_comparison_dataset(
-    path="CarperAI/openai_summarize_comparisons", split="train"
-):
+def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"):
     dataset = load_dataset(path, split=split)
     if split == "test":
         dataset = dataset.select(range(5000))
@@ -95,16 +93,12 @@ def __call__(self, data):
 model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
 model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))
 max_length = 550
-val_pairs = create_comparison_dataset(
-    "CarperAI/openai_summarize_comparisons", "test"
-)
+val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test")
 dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)

 from torch.utils.data import DataLoader

-dev_dataloader = DataLoader(
-    dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward()
-)
+dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward())
 model.cuda()
 model.eval()
 model.half()
@@ -116,9 +110,7 @@ def __call__(self, data):
     for x in batch:
         batch[x] = batch[x].cuda()
     outputs = model(**batch)
-    correct += sum(
-        outputs["chosen_end_scores"] > outputs["rejected_end_scores"]
-    )
+    correct += sum(outputs["chosen_end_scores"] > outputs["rejected_end_scores"])
     chosen_list.append(outputs["chosen_end_scores"].cpu())
     reject_list.append(outputs["rejected_end_scores"].cpu())
 print("Total accuracy: ", correct / len(dev_dataset))

examples/summarize_rlhf/reward_model/reward_model.py

+2-8
@@ -9,11 +9,7 @@ def __init__(self, model_path):
         model = AutoModelForCausalLM.from_pretrained(model_path)
         self.config = model.config
         # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd``
-        self.config.n_embd = (
-            self.config.hidden_size
-            if hasattr(self.config, "hidden_size")
-            else self.config.n_embd
-        )
+        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
         self.transformer = model.transformer
         self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
         self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@@ -91,9 +87,7 @@ def forward(
             rejected_end_scores.append(r_truncated_reward[-1])

             # Compute loss
-            loss += -torch.log(
-                torch.sigmoid(c_truncated_reward - r_truncated_reward)
-            ).mean()
+            loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
         loss = loss / bs

         if not inference:

examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

+1-3
@@ -8,9 +8,7 @@
 from transformers import AutoTokenizer, Trainer, TrainingArguments


-def create_comparison_dataset(
-    path="CarperAI/openai_summarize_comparisons", split="train"
-):
+def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"):
     dataset = load_dataset(path, split=split)
     pairs = []
     for sample in tqdm(dataset):

examples/summarize_rlhf/sft/summarize_dataset.py

+7-19
@@ -43,9 +43,7 @@ def __len__(self):

     def __getitem__(self, idx):
         txt = self.post_list[idx]
-        encodings_dict = self.tokenizer(
-            txt, truncation=True, max_length=self.max_length, padding="max_length"
-        )
+        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
         input_ids = torch.tensor(encodings_dict["input_ids"])
         attn_masks = torch.tensor(encodings_dict["attention_mask"])

@@ -75,19 +73,11 @@ def make_text(post, summarize):
             self.post_list.append(sample["info"]["post"])
             # NOTE: The chosen summary is always the first one, i.e. `sample["summaries"][0]`
             if sample["choice"] == 0:
-                self.summaries_0.append(
-                    make_text(sample["info"], sample["summaries"][0]["text"])
-                )
-                self.summaries_1.append(
-                    make_text(sample["info"], sample["summaries"][1]["text"])
-                )
+                self.summaries_0.append(make_text(sample["info"], sample["summaries"][0]["text"]))
+                self.summaries_1.append(make_text(sample["info"], sample["summaries"][1]["text"]))
             else:
-                self.summaries_0.append(
-                    make_text(sample["info"], sample["summaries"][1]["text"])
-                )
-                self.summaries_1.append(
-                    make_text(sample["info"], sample["summaries"][0]["text"])
-                )
+                self.summaries_0.append(make_text(sample["info"], sample["summaries"][1]["text"]))
+                self.summaries_1.append(make_text(sample["info"], sample["summaries"][0]["text"]))
             self.labels.append(0)

     def __len__(self):
@@ -113,7 +103,7 @@ def __init__(self, train_path, tokenizer, split, max_length=1024):
         if split == "valid":
             df = df.sample(n=5000)
         self.summarizes = []
-        for (i, row) in df.iterrows():
+        for i, row in df.iterrows():
             self.summarizes.append(f"Summarize: {row['text']}. TL;DR: {row['summary']}")
         self.tokenizer = tokenizer
         self.max_length = max_length
@@ -125,9 +115,7 @@ def __len__(self):

     def __getitem__(self, idx):
         txt = self.summarizes[idx]
-        encodings_dict = self.tokenizer(
-            txt, truncation=True, max_length=self.max_length, padding="max_length"
-        )
+        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
         input_ids = torch.tensor(encodings_dict["input_ids"])
         attn_masks = torch.tensor(encodings_dict["attention_mask"])

examples/summarize_rlhf/trlx_gptj_text_summarization.py

+6-18
@@ -22,7 +22,6 @@


 if __name__ == "__main__":
-
     # Load the pre-trained reward model
     rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
     rw_tokenizer.pad_token = rw_tokenizer.eos_token
@@ -38,9 +37,7 @@ def get_scores(samples: List[str]):
         batch_size = 2
         for i in range(0, len(samples), batch_size):
             sub_samples = samples[i : i + batch_size]
-            sub_samples = [
-                "<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples
-            ]
+            sub_samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples]
             encodings_dict = rw_tokenizer(
                 sub_samples,
                 truncation=True,
@@ -69,8 +66,7 @@ def get_prompt_dataset(prompts, max_length):
                 tokenizer(
                     prompts[i].split("TL;DR:")[0],
                     truncation=True,
-                    max_length=max_length
-                    - 5,  # to make sure "TL;DR" dont get truncated
+                    max_length=max_length - 5,  # to make sure "TL;DR" dont get truncated
                 )["input_ids"],
                 skip_special_tokens=True,
             ).strip()
@@ -84,25 +80,19 @@ def get_prompt_dataset(prompts, max_length):

     def reward_fn(samples: List[str], **kwargs):
         original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples]
-        original_samples = [
-            text + post_summary_dict[text.strip()] for text in original_samples
-        ]
+        original_samples = [text + post_summary_dict[text.strip()] for text in original_samples]
         original_scores = get_scores(original_samples)
         scores = get_scores(samples)
         norms_scores = scores - original_scores
         return norms_scores

-    config_path = pathlib.Path(__file__).parent.joinpath(
-        "configs/ppo_config_summ_gptj.yml"
-    )
+    config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_config_summ_gptj.yml")
     config = TRLConfig.load_yaml(config_path)

     tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "left"
-    max_length_input = (
-        config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
-    )
+    max_length_input = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

     dataset = load_dataset("CarperAI/openai_summarize_tldr")

@@ -127,8 +117,6 @@ def reward_fn(samples: List[str], **kwargs):
         config.model.model_path,
         reward_fn=reward_fn,
         prompts=train_prompts,
-        eval_prompts=val_prompts[
-            0:1000
-        ],  # sampling 1000 validation prompts for evaluation speed in training
+        eval_prompts=val_prompts[0:1000],  # sampling 1000 validation prompts for evaluation speed in training
         config=config,
     )
