Major updates
artemisp committed Mar 2, 2024
1 parent aa5c29f commit 0805275
Showing 23 changed files with 2,011 additions and 1,273 deletions.
11 changes: 7 additions & 4 deletions .env_template
@@ -1,12 +1,15 @@
# Location to download data
DATA_DIR=<pwd>/data
# Location to project directory
PROJ_DIR=<pwd>
# WANDB Parameters and Info
WANDB_PROJ_NAME=<PROJ_NAME>
WANDB_DIR=<pwd>/wandb
WANDB_RESUME=allow
# Cache directory (models, datasets)
CACHE_DIR=<pwd>/.cache
# Checkpoint directory
CHECKPOINT_DIR=<pwd>/checkpoints
# Prediction output directory
PREDICTION_DIR=<pwd>/predictions
# Checkpoints and Results directory
OUTPUT_DIR=<pwd>/output
# Huggingface access token
# see: https://huggingface.co/docs/hub/en/security-tokens
HF_ACCESS_TOKEN=<huggingface_token>
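These variables are read at runtime; a minimal loading sketch, assuming the project uses python-dotenv (which requirements.txt lists):

import os
from dotenv import load_dotenv

load_dotenv()  # reads the .env file created from this template
data_dir = os.getenv("DATA_DIR")
cache_dir = os.getenv("CACHE_DIR", os.path.join(os.getcwd(), ".cache"))
hf_token = os.getenv("HF_ACCESS_TOKEN")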
683 changes: 335 additions & 348 deletions README.md

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions requirements.txt
@@ -1,7 +1,13 @@
transformers==4.28.1
datasets==2.11.0
tokenizers==0.12.0
transformers
datasets
tokenizers
evaluate
pytorch-lightning
mmcv-lite
pandas
python-dotenv
wandb
wandb
packaging
peft
accelerate
bitsandbytes
7 changes: 7 additions & 0 deletions src/common/checkpoint_utils.py
@@ -0,0 +1,7 @@
def trim_prefix(state_dict):
    """Keep only the learned prefix-tuning (soft prompt) parameters."""
    state_dict = {k: v for k, v in state_dict.items() if 'learned_embedding' in k}
    return state_dict

def trim_lora(state_dict):
    """Keep only the LoRA adapter parameters."""
    state_dict = {k: v for k, v in state_dict.items() if 'lora' in k.lower()}
    return state_dict
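A hedged usage sketch for these helpers (hypothetical; `model` and the output paths are assumed): they let a checkpoint keep only the small set of trainable parameters rather than the frozen base model.

import torch
# save only the trainable adapter / soft-prompt weights, not the base model
torch.save(trim_lora(model.state_dict()), "lora_only.ckpt")
torch.save(trim_prefix(model.state_dict()), "prefix_only.ckpt")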
99 changes: 0 additions & 99 deletions src/configs/base.py

This file was deleted.

188 changes: 188 additions & 0 deletions src/configs/base_llama.py
@@ -0,0 +1,188 @@
import os
proj_dir=os.getcwd()

seed=42
debug=True
strategy='ddp'

prefix_tuning=False
prefix_tokens=30

# Output directory
output_dir = f'{os.getenv("OUTPUT_DIR", f"{proj_dir}/output")}/llama2/lora/natural_instructions_200k'
resume_from_checkpoint = None
metrics = ['bleu']

raw_data = "Muennighoff/natural-instructions"


preprocessing_kwargs = {
"remove_html": False,
"pad_punctuation": False,
"drop_tables": False,
"column_dict": {"inputs": ["definition", "inputs"], "target": "targets"},
"input_template": "[INST] {} {} [/INST]",
"target_template": "{}",
"concat_input_output": True,
"keep_columns": ["definition", "input", "target", "context_aware_embeds"],
}
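# Illustration only (hypothetical example row; field names follow column_dict above):
# the templates render one natural-instructions record into a Llama-2 chat-style prompt.
#   ex = {"definition": "Translate to French.", "inputs": "Good morning.", "targets": "Bonjour."}
#   prompt = preprocessing_kwargs["input_template"].format(ex["definition"], ex["inputs"])
#   # -> "[INST] Translate to French. Good morning. [/INST]"
#   target = preprocessing_kwargs["target_template"].format(ex["targets"])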


tokenization_kwargs = {
"tokenizer_name": 'meta-llama/Llama-2-7b-hf',
"max_input_length": 1024,
"max_target_length": 1024,
"padding": "max_length",
"truncation": True,
"concat_input_output": True,
"prefix_tuning": prefix_tuning,
"n_prefix_tokens": prefix_tokens,
"decoder_prefix": False,
"pad_token": 'unk_token'
}
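# Sketch of how these settings map onto a Hugging Face tokenizer (assumption:
# the datamodule does something equivalent; the gated Llama-2 checkpoint needs
# the HF_ACCESS_TOKEN from .env_template):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained(tokenization_kwargs["tokenizer_name"])
#   tok.pad_token = tok.unk_token  # pad_token='unk_token': Llama-2 ships without a pad token
#   enc = tok(text, max_length=tokenization_kwargs["max_input_length"],
#             padding=tokenization_kwargs["padding"], truncation=True)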

# Datamodule Arguments
datamodule_kwargs = {
"debug": debug,
"strategy": strategy,
"raw_data": raw_data,
"deduplicate_columns": ["id"],
"load_from_cache_file": False,
"num_workers": 12,
"batch_size": 2,
"shots": 10000,
"dev_from_train": -1, ## set to -1 if use dev file for validation, else subsample from train
"overfit": False,
"dev_size": 1024,
"tiny": False,
"tiny_size": 1024,
"filter_long_sequences": True,
"preprocessing_kwargs": preprocessing_kwargs,
"tokenization_kwargs": tokenization_kwargs,
"batch_tokenize": True,
"predict_split": 'dev',

}
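# Sketch (assumption: the datamodule pulls the hub dataset roughly like this):
#   from datasets import load_dataset
#   raw = load_dataset(raw_data)  # Muennighoff/natural-instructions
#   train = raw["train"].shuffle(seed=seed).select(range(datamodule_kwargs["shots"]))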


## logger arguments
logger_type='wandb'
logger_kwargs = {
'name': 'llama2/lora/natural_instructions_200k',
'save_dir': os.getenv("OUTPUT_DIR", f"{proj_dir}/wandb_logs"),
'project': os.getenv("WANDB_PROJ_NAME", "test"),
'log_model': False,
'resume': os.getenv("WANDB_RESUME", "allow"),
}
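# These kwargs line up with Lightning's WandbLogger; a sketch, assuming that is
# the class selected by logger_type='wandb':
#   from pytorch_lightning.loggers import WandbLogger
#   logger = WandbLogger(**logger_kwargs)  # 'resume' is forwarded to wandb.init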

optimizer_config = {
"lr": 1e-4,
"eps": 1e-8,
"weight_decay": 1e-4,
"scheduler": "CosineAnnealingLR",

}

lora_config = {
"r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"bias": "none",
"target_modules": ['q_proj','v_proj', 'k_proj', 'lm_head'],
"task_type": "CAUSAL_LM",
}
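# The keys above mirror peft.LoraConfig one-to-one, so the adapter config can
# presumably be built directly from this dict (sketch):
#   from peft import LoraConfig
#   peft_cfg = LoraConfig(**lora_config)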


quantization_config = {
"load_in_8bit":True,
"bnb_8bit_use_double_quant":True,
"bnb_8bit_quant_type":"nf8",
"bnb_8bit_compute_dtype": "bfloat16"
}

generation_kwargs= {
"max_new_tokens": 30,
"min_new_tokens": 1,
"num_return_sequences": 1,
"do_sample": False,
}
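# At prediction time these are generate() arguments; with do_sample=False the
# decoding is greedy (usage sketch, model/inputs hypothetical):
#   outputs = model.generate(**inputs, **generation_kwargs)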

# Model Arguments
module_kwargs = {
"model_name": 'meta-llama/Llama-2-7b-hf',
"optimizer": 'AdamW',
"auto_model_class": "AutoModelForCausalLM",
"prefix_tuning": prefix_tuning,
"n_prefix_tokens": prefix_tokens,
"initialize_from_vocab": False,

"optimizer_type": "AdamW",
"optimizer_config": optimizer_config,
"gradient_checkpointing": True,
"quantization_precision": 8,
"precision": "bf16",
"tokenization_kwargs": tokenization_kwargs,

"lora": True,
"lora_config": lora_config,
"quantization": True,
"quantization_config": quantization_config,

"generation_kwargs": generation_kwargs,

"freeze_encoder": False,
"freeze_encoder_layers": [],
"freeze_decoder": False,
"freeze_decoder_layers": [],
"keep_in_fp32_modules": [],
"resume_from_checkpoint": resume_from_checkpoint,
"postproc_fn": "identity",
}
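# End-to-end sketch of how these flags typically combine with transformers +
# peft (assumption: the project's module wraps equivalent calls; only
# load_in_8bit is passed to BitsAndBytesConfig here, the documented 8-bit switch):
#   import torch
#   from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#   from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
#   model = AutoModelForCausalLM.from_pretrained(
#       module_kwargs["model_name"],
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       torch_dtype=torch.bfloat16,
#   )
#   model.gradient_checkpointing_enable()
#   model = prepare_model_for_kbit_training(model)
#   model = get_peft_model(model, LoraConfig(**lora_config))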


# Callbacks
checkpoint_callback=True
checkpoint_callback_kwargs = {
"dirpath": output_dir,
"verbose": True,
"monitor": "val_loss",
"mode": "min",
"save_last": True,
"save_top_k": 1,
"every_n_train_steps": 10,
"save_on_train_epoch_end": False
}
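# Maps directly onto Lightning's ModelCheckpoint callback (sketch):
#   from pytorch_lightning.callbacks import ModelCheckpoint
#   callbacks = [ModelCheckpoint(**checkpoint_callback_kwargs)] if checkpoint_callback else []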

# Trainer Arguments
accelerator='auto'
devices="auto"
num_nodes=1
precision="bf16-mixed"
fast_dev_run=False
max_epochs=1
min_epochs=None
max_steps=100000
min_steps=1000
max_time=None
limit_train_batches=None
limit_val_batches=None
limit_test_batches=None
limit_predict_batches=None
overfit_batches=0.0
val_check_interval=.1
check_val_every_n_epoch=1
num_sanity_val_steps=0
log_every_n_steps=50
enable_progress_bar=True
enable_model_summary=True
accumulate_grad_batches=4
gradient_clip_val=0.3
gradient_clip_algorithm='norm'
deterministic=None
benchmark=None
inference_mode=True
profiler=None
detect_anomaly=False
barebones=False
sync_batchnorm=strategy in ['ddp', 'fsdp','fsdp_native', 'ddp_find_unused_parameters_true']
reload_dataloaders_every_n_epochs=0
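# The module-level names above correspond to pytorch_lightning.Trainer
# arguments; a sketch of how a training script might wire them up
# (logger, callbacks, module, and datamodule are hypothetical):
#   import pytorch_lightning as pl
#   trainer = pl.Trainer(
#       accelerator=accelerator, devices=devices, num_nodes=num_nodes,
#       precision=precision, max_epochs=max_epochs, max_steps=max_steps,
#       val_check_interval=val_check_interval,
#       accumulate_grad_batches=accumulate_grad_batches,
#       gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm,
#       logger=logger, callbacks=callbacks,
#   )
#   trainer.fit(module, datamodule=datamodule)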