Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
fegin committed Feb 12, 2025
1 parent c131309 commit 2f4d1ce
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 5 additions & 1 deletion torchtitan/models/llama/pipeline_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class, ScheduleZBVZeroBubble
from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
get_schedule_class,
ScheduleZBVZeroBubble,
)

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
Expand Down
3 changes: 2 additions & 1 deletion torchtitan/train_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class ModelProtocol(Protocol):
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module: ...
def from_model_args(args: BaseModelArgs) -> nn.Module:
...


OptimizersBuilder: TypeAlias = Callable[
Expand Down

0 comments on commit 2f4d1ce

Please sign in to comment.