From 2f4d1ce250f9ff294a0175a0a2f33c5c3a8a40c9 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Feb 2025 17:29:23 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- torchtitan/models/llama/pipeline_llama.py | 6 +++++- torchtitan/train_spec.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/llama/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py index bdd62813..6a3622ba 100644 --- a/torchtitan/models/llama/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -13,7 +13,11 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class, ScheduleZBVZeroBubble +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + ScheduleZBVZeroBubble, +) from torchtitan.config_manager import JobConfig from torchtitan.logging import logger diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py index 222ff97f..f76b1a8d 100644 --- a/torchtitan/train_spec.py +++ b/torchtitan/train_spec.py @@ -36,7 +36,8 @@ class ModelProtocol(Protocol): """ @staticmethod - def from_model_args(args: BaseModelArgs) -> nn.Module: ... + def from_model_args(args: BaseModelArgs) -> nn.Module: + ... OptimizersBuilder: TypeAlias = Callable[