Skip to content

Commit

Permalink
add kto to webui
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed May 20, 2024
1 parent d52fae2 commit 9b0f4d7
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 38 deletions.
21 changes: 16 additions & 5 deletions src/llamafactory/webui/components/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
with gr.Column():
ppo_score_norm = gr.Checkbox()
ppo_whiten_rewards = gr.Checkbox()

input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
dict(
rlhf_tab=rlhf_tab,
pref_beta=pref_beta,
pref_ftx=pref_ftx,
pref_loss=pref_loss,
reward_model=reward_model,
ppo_score_norm=ppo_score_norm,
ppo_whiten_rewards=ppo_whiten_rewards,
)
)

with gr.Accordion(open=False) as galore_tab:
Expand Down
72 changes: 50 additions & 22 deletions src/llamafactory/webui/locales.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,52 +774,52 @@
"label": "RLHF 参数设置",
},
},
"dpo_beta": {
"pref_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss.",
"label": "Beta value",
"info": "Value of the beta parameter in the loss.",
},
"ru": {
"label": "DPO бета",
"info": "Значение параметра бета в функции потерь DPO.",
"label": "Бета значение",
"info": "Значение параметра бета в функции потерь.",
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。",
"label": "Beta 参数",
"info": "损失函数中 beta 超参数大小。",
},
},
"dpo_ftx": {
"pref_ftx": {
"en": {
"label": "DPO-ftx weight",
"info": "The weight of SFT loss in the DPO-ftx.",
"label": "Ftx gamma",
"info": "The weight of SFT loss in the final loss.",
},
"ru": {
"label": "Вес DPO-ftx",
"info": "Вес функции потерь SFT в DPO-ftx.",
"label": "Ftx гамма",
"info": "Вес потери SFT в итоговой потере.",
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。",
"label": "Ftx gamma",
"info": "损失函数中 SFT 损失的权重大小。",
},
},
"orpo_beta": {
"pref_loss": {
"en": {
"label": "ORPO beta",
"info": "Value of the beta parameter in the ORPO loss.",
"label": "Loss type",
"info": "The type of the loss function.",
},
"ru": {
"label": "ORPO бета",
"info": "Значение параметра бета в функции потерь ORPO.",
"label": "Тип потерь",
"info": "Тип функции потерь.",
},
"zh": {
"label": "ORPO beta 参数",
"info": "ORPO 损失函数中 beta 超参数大小。",
"label": "损失类型",
"info": "损失函数的类型。",
},
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Adapter of the reward model for PPO training.",
"info": "Adapter of the reward model in PPO training.",
},
"ru": {
"label": "Модель вознаграждения",
Expand All @@ -830,6 +830,34 @@
"info": "PPO 训练中奖励模型的适配器路径。",
},
},
"ppo_score_norm": {
"en": {
"label": "Score norm",
"info": "Normalizing scores in PPO training.",
},
"ru": {
"label": "Норма оценок",
"info": "Нормализация оценок в тренировке PPO.",
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中归一化奖励分数。",
},
},
"ppo_whiten_rewards": {
"en": {
"label": "Whiten rewards",
"info": "Whiten the rewards in PPO training.",
},
"ru": {
"label": "Белые вознаграждения",
"info": "Осветлите вознаграждения в обучении PPO.",
},
"zh": {
"label": "白化奖励",
"info": "PPO 训练中将奖励分数做白化处理。",
},
},
"galore_tab": {
"en": {
"label": "GaLore configurations",
Expand Down
36 changes: 25 additions & 11 deletions src/llamafactory/webui/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,14 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
plot_loss=True,
)

# freeze config
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
elif args["finetuning_type"] == "lora":

# lora config
if args["finetuning_type"] == "lora":
args["lora_rank"] = get("train.lora_rank")
args["lora_alpha"] = get("train.lora_alpha")
args["lora_dropout"] = get("train.lora_dropout")
Expand All @@ -163,6 +166,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
if args["use_llama_pro"]:
args["num_layer_trainable"] = get("train.num_layer_trainable")

# rlhf config
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
[
Expand All @@ -171,31 +175,41 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
args["ppo_score_norm"] = get("train.ppo_score_norm")
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
args["top_k"] = 0
args["top_p"] = 0.9
elif args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx")
args["dpo_beta"] = get("train.pref_beta")
args["dpo_ftx"] = get("train.pref_ftx")
args["dpo_loss"] = get("train.pref_loss")
elif args["stage"] == "kto":
args["kto_beta"] = get("train.pref_beta")
args["kto_ftx"] = get("train.pref_ftx")
elif args["stage"] == "orpo":
args["orpo_beta"] = get("train.orpo_beta")

if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
args["orpo_beta"] = get("train.pref_beta")

# galore config
if args["use_galore"]:
args["galore_rank"] = get("train.galore_rank")
args["galore_update_interval"] = get("train.galore_update_interval")
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")

# badam config
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")

# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]

return args

def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 9b0f4d7

Please sign in to comment.