Patch summary (text above the "diff --git" line is ignored by patch tools):
  * --gradient-accumulation-steps now uses the short flag -G, so it no longer
    shares -g with --max-grad-norm.
  * --push-to-hub now uses argparse.BooleanOptionalAction instead of type=bool;
    with type=bool any non-empty value (including "False") parses as True.
    BooleanOptionalAction also provides --no-push-to-hub (Python 3.9+).
NOTE(review): the hunk's indentation was reconstructed at 4 spaces - confirm
against scripts/sft.py before applying.

diff --git a/scripts/sft.py b/scripts/sft.py
index 843d6fadd..baa52e156 100644
--- a/scripts/sft.py
+++ b/scripts/sft.py
@@ -66,11 +66,11 @@ def main(args):
     parser.add_argument("--name-to-save", "-n", type=str, default="Qwen3-1.7B-Wordle")
     parser.add_argument("--max-length", "-l", type=int, default=8192)
     parser.add_argument("--per-device-train-batch-size", "-b", type=int, default=2)
-    parser.add_argument("--gradient-accumulation-steps", "-g", type=int, default=1)
+    parser.add_argument("--gradient-accumulation-steps", "-G", type=int, default=1)
     parser.add_argument("--learning-rate", "-r", type=float, default=2e-5)
     parser.add_argument("--num-train-epochs", "-e", type=int, default=3)
     parser.add_argument("--weight-decay", "-w", type=float, default=0.01)
     parser.add_argument("--max-grad-norm", "-g", type=float, default=0.1)
-    parser.add_argument("--push-to-hub", "-p", type=bool, default=True)
+    parser.add_argument("--push-to-hub", "-p", action=argparse.BooleanOptionalAction, default=True)
     args = parser.parse_args()
     main(args)