
Commit 440b973

update requirements.txt
1 parent b618ddb commit 440b973

2 files changed (+2, -10)


baselines/src/finetune.py (-9)
@@ -1,7 +1,6 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed, Trainer
 
-import wandb
 import transformers
 import os
 from peft import LoraConfig, get_peft_model
@@ -10,8 +9,6 @@
 from src.utils import get_model_identifiers_from_yaml, find_all_linear_names
 from src.dataset import QADataset, DefaultDataset
 
-os.environ['WANDB_MODE'] = 'dryrun'
-
 def finetune(cfg):
     if os.environ.get('LOCAL_RANK') is not None:
         local_rank = int(os.environ.get('LOCAL_RANK', '0'))
@@ -28,12 +25,6 @@ def finetune(cfg):
     model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
     model_id = model_cfg["hf_key"]
 
-    wandb.init(project='finetune', config={
-        "learning_rate": cfg.lr,
-        "epochs": cfg.num_epochs,
-        "batch_size": batch_size * gradient_accumulation_steps * num_devices,
-    }, name=f'finetune-lr{cfg.lr}-epoch{cfg.num_epochs}')
-
     Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
     # save the cfg file
     #if master process
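
Note: the diff above removes Weights & Biases logging entirely (the import, the WANDB_MODE override, and the wandb.init call). For comparison, a minimal sketch of how the same logging could instead be kept behind an opt-in flag rather than deleted. This is an illustration, not part of the commit: `cfg.use_wandb` is a hypothetical config field that does not exist in this repo, while the project name, run name, and effective-batch-size arithmetic mirror the removed call.

    import os

    def init_wandb_if_enabled(cfg, batch_size, gradient_accumulation_steps, num_devices):
        # Hypothetical helper: only start a wandb run when cfg.use_wandb is set,
        # so the training script has no hard dependency on the package.
        if not getattr(cfg, 'use_wandb', False):
            os.environ['WANDB_MODE'] = 'disabled'  # any stray wandb calls become no-ops
            return None
        import wandb  # imported lazily, keeping wandb optional at install time
        return wandb.init(
            project='finetune',
            config={
                'learning_rate': cfg.lr,
                'epochs': cfg.num_epochs,
                # effective batch size, as computed in the removed wandb.init call
                'batch_size': batch_size * gradient_accumulation_steps * num_devices,
            },
            name=f'finetune-lr{cfg.lr}-epoch{cfg.num_epochs}',
        )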

requirements.txt (+2, -1)
@@ -16,4 +16,5 @@ scipy==1.14.1
 ninja==1.11.1.2
 zhipuai==2.1.5.20241203
 openai==1.55.3
-vllm==0.6.5
+vllm==0.6.5
+scikit-learn==1.6.1
