126 changes: 125 additions & 1 deletion torchrec/distributed/train_pipeline/tests/test_tracing.py
@@ -8,19 +8,23 @@
# pyre-strict

import unittest
from typing import List
from typing import List, Optional
from unittest.mock import MagicMock

import parameterized

import torch
from torch import nn
from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
from torchrec.distributed.train_pipeline.tracing import (
_get_leaf_module_names,
ArgInfo,
ArgInfoStepFactory,
CallArgs,
NodeArgsHelper,
Tracer,
)
from torchrec.distributed.types import NullShardedModuleContext, ShardedModule
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -110,3 +114,123 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:

        # Weights is a call_module node, so we should only find 2 args unmodified
        self.assertEqual(num_found, len(kjt_args) - 1)


class DummyShardedModule(
    ShardedModule[torch.Tensor, torch.Tensor, torch.Tensor, NullShardedModuleContext]
):
    def __init__(self, alpha: float = 1) -> None:
        super().__init__()
        self.alpha = alpha

    # pyre-ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.alpha * x

    # pyre-ignore
    def compute(self) -> torch.Tensor:
        return torch.empty(0)

    def create_context(self) -> NullShardedModuleContext:
        return NullShardedModuleContext()

    # pyre-ignore
    def input_dist(self, ctx: NullShardedModuleContext):
        pass

    # pyre-ignore
    def output_dist(self):
        pass

    # pyre-ignore
    def unsharded_module_type(self):
        pass


class DummyUmbrellaModule(nn.Module):
    def __init__(self, m1: nn.Module, m2: nn.Module) -> None:
        super().__init__()
        self.m1 = m1
        self.m2 = m2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.m1(x) + self.m2(x)


class DummyNestedModule(nn.Module):
    def __init__(self, layer: int = 0) -> None:
        super().__init__()
        self.layer = layer
        self.inner: Optional[nn.Module] = (
            DummyNestedModule(layer - 1) if layer > 0 else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inner = 0 if self.inner is None else self.inner(x)
        return inner + 10**self.layer


class TestFxTracer(unittest.TestCase):
    @classmethod
    def _generate_sharded_model(cls) -> nn.Module:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested = DummyNestedModule(3)
                self.umbrella1 = DummyUmbrellaModule(
                    DummyNestedModule(2), DummyShardedModule()
                )
                self.umbrella2 = DummyUmbrellaModule(
                    DummyNestedModule(3), DummyShardedModule()
                )
                self.umbrella3 = DummyUmbrellaModule(
                    DummyNestedModule(4), DummyNestedModule(5)
                )
                self.umbrella4 = DummyUmbrellaModule(
                    DummyNestedModule(6), DummyNestedModule(7)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return (
                    # umbrella2 and umbrella4 are not directly
                    # called in this forward function
                    self.nested(x)
                    + self.umbrella1(x)
                    + self.umbrella2.m1(x)
                    + self.umbrella2.m2(x)
                    + self.umbrella3(x)
                    + self.umbrella4.m1(x)
                    + self.umbrella4.m2(x)
                )

        return MyModel()

    def test_get_leaf_module_names(self) -> None:
        model = self._generate_sharded_model()
        leaf_modules = _get_leaf_module_names(model)
        self.assertSetEqual(
            set(leaf_modules),  # umbrella1.m2 and umbrella2.m2 are `ShardedModule`s
            {"nested", "umbrella1.m1", "umbrella2.m1", "umbrella3", "umbrella4"},
        )

    def test_top_level_tracer(self) -> None:
        model = self._generate_sharded_model()
        concrete_args = {}
        tracer = Tracer(
            leaf_modules=_get_leaf_module_names(model), extend_leaf_fqn=True
        )
        graph = tracer.trace(model, concrete_args=concrete_args)
        targets = {node.target for node in graph.nodes if node.op == "call_module"}
        self.assertSetEqual(
            targets,
            {
                "nested",
                "umbrella1.m1",
                "umbrella1.m2",
                "umbrella2.m1",
                "umbrella2.m2",
                "umbrella3",
                "umbrella4.m1",  # umbrella4 itself is never called in model.forward,
                "umbrella4.m2",  # so only its children appear as call_module nodes
            },
        )
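
A quick sketch of why the expected sets above look the way they do: `_get_leaf_module_names` stops descending as soon as a subtree contains no `ShardedModule` (see the docstring in `tracing.py` below). The `contains_sharded` helper here is hypothetical, a simplified illustration rather than the actual implementation.

```python
# Illustration only: a simplified view of the leaf-module rule exercised above.
# A subtree with no ShardedModule inside becomes a single leaf; a subtree that
# does contain one is split into its children.
import torch
from torchrec.distributed.types import ShardedModule


def contains_sharded(module: torch.nn.Module) -> bool:
    return any(isinstance(m, ShardedModule) for m in module.modules())


model = TestFxTracer._generate_sharded_model()
assert contains_sharded(model.umbrella1)      # split into "umbrella1.m1" / "umbrella1.m2"
assert not contains_sharded(model.umbrella3)  # stays a single leaf: "umbrella3"
```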
52 changes: 51 additions & 1 deletion torchrec/distributed/train_pipeline/tracing.py
@@ -514,6 +514,29 @@ def _get_leaf_module_names(model: torch.nn.Module) -> List[str]:
    This is a shallow FX trace that only goes to the minimum depth required to pipeline.
    Any sub-module that does not contain a ShardedModule is considered a leaf
    module unless it is explicitly tagged with `_is_pytorch_fx_traceable = True`.

    Disclaimer:
    The algorithm is based on the `named_modules()` API of `torch.nn.Module`, so there
    can be corner cases where a returned leaf module is never actually called in the
    forward pass. For example, in the snippet below, the umbrella_module would be
    considered the top-level leaf module, yet it would not appear in the FX-traced graph.

    ```
    # the main_model's hierarchy looks like below:
    main_model
    - sharded_module
    - umbrella_module
        - actual_leaf_module_1
        - actual_leaf_module_2

    # and the main_model's forward is something like:
    def forward(self, x1, x2, x3):
        emb1 = self.sharded_module(x1)
        emb2 = self.umbrella_module.actual_leaf_module_1(x2)
        emb3 = self.umbrella_module.actual_leaf_module_2(x3)
        return emb1 + emb2 + emb3
    ```

    """

    def _get_leaf_module_names_helper(
@@ -573,9 +596,22 @@ class Tracer(torch.fx.Tracer):
    # remove this line.
    proxy_buffer_attributes = False

    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
    def __init__(
        self, leaf_modules: Optional[List[str]] = None, extend_leaf_fqn: bool = False
    ) -> None:
"""
Initializes the Tracer for FX tracing with custom leaf module handling.

Args:
leaf_modules: Optional list of fully qualified names (FQNs) of modules to treat
as leaf modules during tracing. If None, defaults to an empty list.
extend_leaf_fqn: If True, treats any module whose FQN starts with a leaf module
FQN as a leaf module (includes submodules). If False, only exact matches
are considered leaf modules. Defaults to False.
"""
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
self._extend_leaf_fqn = extend_leaf_fqn

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        if (
@@ -585,4 +621,18 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
            or isinstance(m, FSDP2)
        ):
            return True
        if self._extend_leaf_fqn:
            if self.is_extended_leaf_modules(m, module_qualified_name):
                return True
        return super().is_leaf_module(m, module_qualified_name)

    def is_extended_leaf_modules(
        self, m: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        for leaf_module in self._leaf_modules:
            # a fqn like 'main_model.leaf_module.submod' should also be treated
            # as a leaf module when 'main_model.leaf_module' is a leaf module;
            # requiring the trailing "." avoids matching e.g. 'main_model.leaf_module2'
            if module_qualified_name.startswith(leaf_module + "."):
                return True
        return False
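
For context, a minimal end-to-end sketch of how the new `extend_leaf_fqn` flag is meant to be used, mirroring `test_top_level_tracer` above; the `model` variable is assumed to be a module tree like the test's `MyModel`.

```python
# Sketch, assuming `model` is a module tree like MyModel in the test above.
from torchrec.distributed.train_pipeline.tracing import _get_leaf_module_names, Tracer

leaf_modules = _get_leaf_module_names(model)  # e.g. includes "umbrella4"
tracer = Tracer(leaf_modules=leaf_modules, extend_leaf_fqn=True)
graph = tracer.trace(model)

# With extend_leaf_fqn=True, FQNs under a leaf FQN (e.g. "umbrella4.m1") are
# also treated as leaves, so they show up as call_module nodes even though
# "umbrella4" itself is never called in forward.
targets = {node.target for node in graph.nodes if node.op == "call_module"}
```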