|
18 | 18 | from torchmetrics import RunningMean |
19 | 19 |
|
20 | 20 | from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable |
21 | | -from litgpt.args import EvalArgs, TrainArgs |
| 21 | +from litgpt.args import EvalArgs, LogArgs, TrainArgs |
22 | 22 | from litgpt.data import Alpaca, DataModule |
23 | 23 | from litgpt.generate.base import generate |
24 | 24 | from litgpt.prompts import save_prompt_style |
@@ -64,6 +64,7 @@ def setup( |
64 | 64 | max_seq_length=None, |
65 | 65 | ), |
66 | 66 | eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), |
| 67 | + log: LogArgs = LogArgs(), |
67 | 68 | optimizer: Union[str, Dict] = "AdamW", |
68 | 69 | logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv", |
69 | 70 | seed: int = 1337, |
@@ -97,7 +98,13 @@ def setup( |
97 | 98 | config = Config.from_file(checkpoint_dir / "model_config.yaml") |
98 | 99 |
|
99 | 100 | precision = precision or get_default_supported_precision(training=True) |
100 | | - logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval) |
| 101 | + logger = choose_logger( |
| 102 | + logger_name, |
| 103 | + out_dir, |
| 104 | + name=f"finetune-{config.name}", |
| 105 | + log_interval=train.log_interval, |
| 106 | + log_args=dataclasses.asdict(log), |
| 107 | + ) |
101 | 108 |
|
102 | 109 | plugins = None |
103 | 110 | if quantize is not None and quantize.startswith("bnb."): |
|
0 commit comments