diff --git a/finetune/adapter.py b/finetune/adapter.py
index f4bf266e..3ae26eaa 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -36,7 +36,7 @@
 save_interval = 1000
 eval_iters = 100
 log_interval = 1
-devices = 1
+devices = torch.cuda.device_count()
 
 # Hyperparameters
 learning_rate = 9e-3
diff --git a/finetune/full.py b/finetune/full.py
index 58932967..3ab673ce 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -31,7 +31,7 @@
 save_interval = 1000
 eval_iters = 100
 log_interval = 100
-devices = 4
+devices = torch.cuda.device_count()
 
 # Hyperparameters
 learning_rate = 3e-5
diff --git a/finetune/lora.py b/finetune/lora.py
index e00e438a..b370e836 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -50,7 +50,8 @@ def main(
     out_dir: str = "out/lora/alpaca",
 ):
 
-    fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
+    devices = torch.cuda.device_count()
+    fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-true")
     fabric.launch()
     fabric.seed_everything(1337 + fabric.global_rank)
 
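
One caveat worth noting with this change, not addressed in the patch itself: torch.cuda.device_count() returns 0 on a CPU-only machine, and Lightning Fabric rejects a device count of 0. The sketch below is a minimal, hedged variant of the lora.py setup showing how a guard could be added; the CPU fallback branch is a hypothetical addition, not part of the scripts, which assume at least one GPU.

import torch
import lightning as L

# torch.cuda.device_count() reports the number of visible CUDA devices
# (it respects CUDA_VISIBLE_DEVICES) and returns 0 when no GPU is present.
devices = torch.cuda.device_count()

if devices > 0:
    # Use every visible GPU, matching the behavior introduced by this diff.
    fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-true")
else:
    # Hypothetical CPU fallback for machines without CUDA devices;
    # the original finetune scripts assume a GPU is available.
    fabric = L.Fabric(accelerator="cpu", devices=1)

fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)

Because device_count() honors CUDA_VISIBLE_DEVICES, users who previously relied on the hardcoded devices = 1 can still restrict training to a single GPU, e.g. CUDA_VISIBLE_DEVICES=0 python finetune/lora.py.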