Skip to content

Commit d47996d

Browse files
glerzing, jon-tow, maxreciprocate
authored
peft to opendelta migration (#434) + memory optimization (#320) (#486)
* Migrate to peft from opendelta for parameter efficient tuning methods (#434) + Collapse reference+learner hydra heads when using LoRa (#320) * fix from_config * Review corrections * ILQL generate when temperature is 0. * revert: guard against experimental 8-bit loading support * format: run `black` --------- Co-authored-by: jon-tow <[email protected]> Co-authored-by: maxreciprocate <[email protected]>
1 parent 171357b commit d47996d

15 files changed

+1046
-404
lines changed

examples/ppo_sentiments_peft.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
2+
# with a sentiment reward function
3+
import json
4+
import os
5+
import sys
6+
from typing import List
7+
8+
import torch
9+
from datasets import load_dataset
10+
from peft import LoraConfig
11+
from peft.utils.config import TaskType
12+
from transformers import pipeline
13+
14+
import trlx
15+
from trlx.data.default_configs import TRLConfig, default_ppo_config
16+
17+
18+
def get_positive_score(scores):
    """Extract the score associated with the "POSITIVE" label from a sentiment pipeline's output."""
    # Each entry is a dict whose values are (label, score) in insertion order,
    # so collapsing them into one mapping gives label -> score.
    label_to_score = dict(tuple(entry.values()) for entry in scores)
    return label_to_score["POSITIVE"]
21+
22+
23+
def main(hparams={}):
    """Tune a pretrained causal LM with PPO on IMDB sentiment, using a LoRA peft config.

    ``hparams`` is a dict of overrides merged on top of the default PPO config.
    """
    # Merge sweep config with default config if given.
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)

    # Pin the reward pipeline to the local GPU rank when CUDA is present; -1 means CPU.
    device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    # Just insert your peft config here (the type must be an instance of
    # peft.PeftConfig or a dict).
    config.model.peft_config = LoraConfig(
        r=8,
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        # Reward each sample with its positive-sentiment probability.
        return [get_positive_score(output) for output in sentiment_fn(samples)]

    # Take a few leading words off of the movie reviews as prompts.
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )
63+
64+
65+
if __name__ == "__main__":
    # Optional JSON hyperparameter overrides may be passed as the first CLI argument.
    cli_args = sys.argv[1:]
    hparams = json.loads(cli_args[0]) if cli_args else {}
    main(hparams)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ numpy==1.24.3
4343
packaging==23.1
4444
pandas==2.0.1
4545
pathtools==0.1.2
46+
peft==0.3.0
4647
pkgutil_resolve_name==1.3.10
4748
platformdirs==3.5.0
4849
protobuf==4.22.3

0 commit comments

Comments
 (0)