Update on "enable TritonFusedRMSNorm with local_map annotation"
**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, giving a 7%-8% performance gain over RMSNorm with TP.
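
For context, here is a minimal sketch of the `local_map` idea the title refers to. It is illustrative only, assuming the experimental `torch.distributed._tensor.experimental.local_map` API (recent PyTorch, roughly 2.4+); the function name, placements, and the eager RMSNorm stand-in below are not the exact code added by this PR.

```python
# Illustrative sketch only -- not the code added by this PR.
# local_map (experimental DTensor API) runs a plain-Tensor function on the local
# shards of DTensor inputs, so a fused Triton kernel can run under Tensor Parallel
# without extra communication. An eager RMSNorm stands in for the Triton kernel.
from functools import partial

import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


@partial(
    local_map,
    out_placements=[Shard(1)],                  # output keeps the sequence-dim sharding
    in_placements=([Shard(1)], [Replicate()]),  # activation sharded, weight replicated
)
def fused_rms_norm_fn(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Inside the wrapper, x and weight are the local torch.Tensor shards, so the
    # fused kernel (or this eager stand-in) runs directly on each rank's shard.
    eps = 1e-6
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
```

With inputs distributed as in `test/test_fused_rms_norm.py` (activation `Shard(1)`, weight `Replicate()`), the call should require zero collectives, which is exactly what the `CommDebugMode` assertion in that test checks.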

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` with 4-way Tensor Parallel (`tensor_parallel_degree = 4`); a full command with the overrides spelled out is sketched after the logs:
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step:  1  loss: 12.2971  memory: 23.68GiB(29.92%)  wps: 258  mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step:  5  loss: 11.6237  memory: 30.98GiB(39.14%)  wps: 453  mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10  loss: 10.7210  memory: 30.98GiB(39.14%)  wps: 580  mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15  loss:  9.4563  memory: 30.98GiB(39.14%)  wps: 585  mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20  loss:  8.9246  memory: 30.98GiB(39.14%)  wps: 582  mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step:  1  loss: 12.2779  memory: 23.64GiB(29.86%)  wps: 186  mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step:  5  loss: 11.6073  memory: 31.11GiB(39.31%)  wps: 447  mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10  loss: 10.6355  memory: 31.11GiB(39.31%)  wps: 606  mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15  loss:  9.5591  memory: 31.11GiB(39.31%)  wps: 596  mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20  loss:  9.0287  memory: 31.11GiB(39.31%)  wps: 605  mfu: 11.23%
```
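
For reference, one possible way to reproduce run (2) with explicit overrides, assuming `run_llama_train.sh` forwards extra flags to the trainer the same way `test_runner.py` does (the flag names come from the test list in this PR, not from the Test Plan itself):

```
CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 \
  ./run_llama_train.sh --training.tensor_parallel_degree 4 --model.norm_type=fused_rmsnorm
```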

[ghstack-poisoned]
XilunWu committed Jun 12, 2024
2 parents 71c725d + 2ddaa8b commit 8c24711
Showing 8 changed files with 45 additions and 29 deletions.
**CONTRIBUTING.md** (6 changes: 5 additions & 1 deletion)

@@ -2,6 +2,10 @@
We want to make contributing to this project as easy and transparent as
possible.

+## Setup
+```
+pip install -r dev-requirements.txt
+```

## Pull Requests
We actively welcome your pull requests.

@@ -10,7 +14,7 @@ We actively welcome your pull requests.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
-5. Make sure your code lints.
+5. Make sure your code lints (`pre-commit run --all-files`).
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
**README.md** (3 changes: 2 additions & 1 deletion)

@@ -1,4 +1,5 @@
-[![4 GPU Integration Test](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml)
+[![4 GPU Integration Test](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml?query=branch%3Amain)
+[![8 GPU Integration Test](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu.yaml?query=branch%3Amain)

# torchtitan

**test/test_fused_rms_norm.py** (6 changes: 3 additions & 3 deletions)

@@ -36,13 +36,13 @@ def test_fused_rms_norm(self):
        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate

-        dx = distribute_tensor(x, mesh, [Shard(1)])
-        dw = distribute_tensor(w, mesh, [Replicate()])
+        dist_x = distribute_tensor(x, mesh, [Shard(1)])
+        dist_w = distribute_tensor(w, mesh, [Replicate()])

        comm_mode = CommDebugMode()
        # fused rmsnorm
        with comm_mode:
-            out = fused_rms_norm_fn(dx, dw)
+            out = fused_rms_norm_fn(dist_x, dist_w)

        self.assertEqual(comm_mode.get_total_counts(), 0)

**test_runner.py** (16 changes: 13 additions & 3 deletions)

@@ -123,12 +123,14 @@ def build_test_list():
                    "--checkpoint.enable_checkpoint",
                    "--experimental.pipeline_parallel_degree 2",
                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_split_mode tracer",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with tracer
                ],
            ],
            "PP tracer frontend test",
            "pp_tracer",
            requires_seed_checkpoint=True,
            ngpu=2,
        ),
        OverrideDefinitions(
            [

@@ -161,8 +163,17 @@ def build_test_list():
                    "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
                ],
            ],
-            "Eager mode 2DParallel",
-            "eager_2d",
+            "Eager mode 2DParallel with rmsnorm",
+            "eager_2d_rmsnorm",
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
+                ],
+            ],
+            "Eager mode 2DParallel with fused_rmsnorm",
+            "eager_2d_fused_rmsnorm",
        ),
        OverrideDefinitions(
            [

@@ -261,7 +272,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
    logger.info(result.stdout)

    for override_arg in test_flavor.override_args:
-
        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
        cmd += " " + dump_folder_arg
        if override_arg:

**torchtitan/parallelisms/parallelize_llama.py** (28 changes: 11 additions & 17 deletions)

@@ -18,12 +18,7 @@
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
-from torch.distributed.pipelining import (
-    ManualPipelineStage,
-    pipeline,
-    PipelineStage,
-    SplitPoint,
-)
+from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,

@@ -236,12 +231,11 @@ def pipeline_llama_manual(
    output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

    model.to_empty(device=device)
-    stage = ManualPipelineStage(
+    stage = PipelineStage(
        model,
        pp_rank,
        pp_size,
        device,
-        microbatches,
        input_args=input.chunk(microbatches)[0],
        output_args=output.chunk(microbatches)[0],
        group=pp_mesh.get_group("pp"),

@@ -267,23 +261,23 @@ pipeline_llama_tracer(

    pp_mesh = world_mesh["pp"]
    pp_rank = pp_mesh.get_local_rank()
+    microbatches = (
+        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
+    )
+    (input,) = _llama_trace_input(job_config, model_config, device=device)
    stage_idx = pp_rank
-    layers_per_rank = len(model.layers) // parallel_dims.pp
    split_spec = {
-        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
-        for i in range(1, parallel_dims.pp)
+        layer_name: SplitPoint.BEGINNING
+        for layer_name in job_config.experimental.pipeline_parallel_split_points
    }

    pipe = pipeline(
        model,
-        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
-        example_args=_llama_trace_input(job_config, model_config),
+        mb_args=(input.chunk(microbatches)[0],),
        split_spec=split_spec,
    )
-    model = pipe.get_stage_module(stage_idx)
-    stage = PipelineStage(
-        pipe,
-        stage_index=stage_idx,
+    stage = pipe.build_stage(
+        stage_idx,
        device=device,
        group=pp_mesh.get_group(),
    )

**torchtitan/parallelisms/pipelining_utils.py** (6 changes: 5 additions & 1 deletion)

@@ -19,8 +19,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
    logger.info(
        f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
    )
+    n_microbatches = job_config.experimental.pipeline_parallel_microbatches
+    if n_microbatches is None:
+        n_microbatches = job_config.experimental.pipeline_parallel_degree
+
    return schedule_class(
        stage,
-        n_microbatches=stage.chunks,
+        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )
**torchtitan/utils.py** (6 changes: 3 additions & 3 deletions)

@@ -50,6 +50,8 @@ def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int
    else:
        metrics_log_rank = 0

+    return metrics_log_rank
+

def set_pg_timeouts(timeout, world_mesh):
    """

@@ -70,9 +72,7 @@ def set_pg_timeouts(timeout, world_mesh):
    torch.distributed.barrier()
    torch.cuda.synchronize()

-    groups = (
-        [world_mesh.get_group()] if world_mesh.ndim == 1 else world_mesh.get_group()
-    )
+    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]

    # None represents the 'default' PG, not part of the mesh
    groups.append(None)

**train.py** (3 changes: 3 additions & 0 deletions)

@@ -351,6 +351,9 @@ def loss_fn(pred, labels):
        with loss_parallel_ctx():
            pred = model(input_ids)
            loss = loss_fn(pred, labels)
+            # pred.shape=(bs, seq_len, vocab_size)
+            # need to free to before bwd to avoid peaking memory
+            del pred
            loss.backward()

        # clip gradients
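
A note on the `del pred` change in train.py: the logits tensor is `(bs, seq_len, vocab_size)`, and if the Python name `pred` still holds it during `loss.backward()`, autograd cannot reclaim that memory once it has finished with the loss node, so peak memory rises. A minimal, self-contained sketch of the pattern (dimensions and the plain cross-entropy loss are illustrative, not taken from train.py):

```python
# Sketch of the peak-memory pattern behind `del pred` (illustrative, not torchtitan code).
import torch
import torch.nn.functional as F

bs, seq_len, hidden, vocab_size = 2, 128, 64, 1024
proj = torch.nn.Linear(hidden, vocab_size)
h = torch.randn(bs, seq_len, hidden)
labels = torch.randint(vocab_size, (bs, seq_len))

pred = proj(h)  # large logits tensor: (bs, seq_len, vocab_size)
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten())
# Drop the Python reference so autograd can free the logits as soon as the
# backward pass no longer needs them, instead of holding them to the end.
del pred
loss.backward()
```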
