|
| 1 | +# Generates positive movie reviews by tuning a pretrained model on IMDB dataset |
| 2 | +# with a sentiment reward function |
| 3 | +import json |
| 4 | +import os |
| 5 | +import sys |
| 6 | +from typing import List |
| 7 | + |
| 8 | +import torch |
| 9 | +from datasets import load_dataset |
| 10 | +from peft import LoraConfig |
| 11 | +from peft.utils.config import TaskType |
| 12 | +from transformers import pipeline |
| 13 | + |
| 14 | +import trlx |
| 15 | +from trlx.data.default_configs import TRLConfig, default_ppo_config |
| 16 | + |
| 17 | + |
def get_positive_score(scores):
    """Extract the value associated with a positive sentiment from pipeline's output.

    Args:
        scores: Iterable of per-class results for a single sample, each a
            mapping holding a ``"label"`` and a ``"score"`` entry.

    Returns:
        The score stored alongside the ``"POSITIVE"`` label.

    Raises:
        KeyError: If no entry is labelled ``"POSITIVE"``, or an entry lacks
            the ``"label"`` or ``"score"`` key.
    """
    # Read the fields by name rather than by their position in
    # ``dict.values()`` so the result does not depend on the key order (or on
    # the presence of extra keys) in each entry.
    score_by_label = {entry["label"]: entry["score"] for entry in scores}
    return score_by_label["POSITIVE"]
| 21 | + |
| 22 | + |
def main(hparams=None):
    """Tune a pretrained model on IMDB prompts using a sentiment reward.

    Builds the default PPO config, attaches a LoRA peft config to it, scores
    generated samples with a sentiment-analysis pipeline and starts
    ``trlx.train``.

    Args:
        hparams: Optional dict of overrides merged into the default PPO
            config through ``TRLConfig.update``. ``None`` (the default) means
            no overrides.
    """
    # Use a ``None`` sentinel instead of a mutable ``{}`` default so no dict
    # object is shared between calls.
    if hparams is None:
        hparams = {}

    # Merge sweep config with default config if given
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)

    # Place the reward pipeline on the GPU indexed by LOCAL_RANK (0 when the
    # variable is unset) if CUDA is available; otherwise pass -1, which the
    # transformers pipeline treats as CPU.
    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    # Just insert your peft config here (the type must be an instance of peft.PeftConfig or a dict).
    config.model.peft_config = LoraConfig(
        r=8,
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        """Return the positive-sentiment score of every sample."""
        return [get_positive_score(scores) for scores in sentiment_fn(samples)]

    # Take the first four words of each movie review as a prompt
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )
| 63 | + |
| 64 | + |
if __name__ == "__main__":
    # An optional first command-line argument carries the hyperparameter
    # overrides as a JSON document; with no argument, no overrides are used.
    if len(sys.argv) == 1:
        hparams = {}
    else:
        hparams = json.loads(sys.argv[1])
    main(hparams)
0 commit comments