diff --git a/torchtitan/models/llama/pipeline_llama.py b/torchtitan/models/llama/pipeline_llama.py index bdd62813..6a3622ba 100644 --- a/torchtitan/models/llama/pipeline_llama.py +++ b/torchtitan/models/llama/pipeline_llama.py @@ -13,7 +13,11 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class, ScheduleZBVZeroBubble +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + ScheduleZBVZeroBubble, +) from torchtitan.config_manager import JobConfig from torchtitan.logging import logger diff --git a/torchtitan/train_spec.py b/torchtitan/train_spec.py index 222ff97f..f76b1a8d 100644 --- a/torchtitan/train_spec.py +++ b/torchtitan/train_spec.py @@ -36,7 +36,8 @@ class ModelProtocol(Protocol): """ @staticmethod - def from_model_args(args: BaseModelArgs) -> nn.Module: ... + def from_model_args(args: BaseModelArgs) -> nn.Module: + ... OptimizersBuilder: TypeAlias = Callable[