
Commit 715c3bb

feat: Add support for multiple quantization modes in the UB communicators (#2043)
1 parent f98e305 commit 715c3bb

File tree

14 files changed: +216 −85 lines

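At a glance, the user-facing change is that initialize_ub now takes a quantization_modes list of UserBufferQuantizationMode values instead of the old boolean use_fp8 flag, so one set of userbuffer communicators can be initialized for more than one quantization mode. Below is a minimal sketch of the new call pieced together from the diffs in this commit; the buffer shape, tensor-parallel size, and dtype are illustrative placeholders, and a real run needs an initialized torch.distributed process group across tensor-parallel GPUs.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

# Illustrative values only (not from the commit); derive these from the model
# and parallel configuration in practice.
sample_shape = [8192, 4096]  # [sequence * batch, hidden size]
tp_size = 2

te.module.base.initialize_ub(
    sample_shape,
    tp_size,
    # Previously: use_fp8=True. Now: one entry per quantization mode the
    # communicators should serve, e.g. FP8 buffers plus unquantized (NONE) buffers.
    quantization_modes=[
        UserBufferQuantizationMode.FP8,
        UserBufferQuantizationMode.NONE,
    ],
    dtype=torch.bfloat16,
)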

docs/api/pytorch.rst

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,7 @@ pyTorch
 
 .. autoapifunction:: transformer_engine.pytorch.moe_permute
 
-.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
+.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
 
 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
 
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub
 
 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+   :members: FP8, NONE

examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 3 additions & 1 deletion
@@ -263,7 +263,9 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )

tests/pytorch/distributed/run_layer_with_overlap.py

Lines changed: 65 additions & 12 deletions
@@ -12,6 +12,8 @@
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial
 
 import torch
 import torch.distributed as dist
@@ -35,9 +37,10 @@ def __init__(self, module, num_layers, *args, **kwargs):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
 
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
 
 
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)
 
     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
         args.use_cuda_graphs = False
 
+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
+
     return args
 
 
@@ -381,10 +418,17 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
         "qkv_dgrad": {"method": "ring_exchange"},
         "fc1_dgrad": {"method": "ring_exchange"},
     }
+
+    quantization_modes = [
+        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+    ]
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(UserBufferQuantizationMode.NONE)
+
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
-        use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
@@ -423,6 +467,16 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     elif opts.quantization == "mxfp8":
         fp8_recipe = MXFP8BlockScaling()
 
+    layer_contexts = [
+        (
+            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
@@ -435,14 +489,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
-                if isinstance(y, tuple):
-                    out, *_ = y
-                else:
-                    out = y
-                loss = out.sum()
-                loss.backward()
+            y = model(x, layer_contexts)
+            if isinstance(y, tuple):
+                out, *_ = y
+            else:
+                out = y
+            loss = out.sum()
+            loss.backward()
         return out
 
     torch_rng_state = torch.get_rng_state()
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py

Lines changed: 7 additions & 1 deletion
@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp

Lines changed: 1 addition & 1 deletion
@@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) {
 }
 
 int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
-  if (comm->free_region > NVTE_MAX_REGIONS) return -1;
+  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
   int hndl = comm->free_region;
   comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
   size_t aligned_size = bytes;

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
 using ExtBarrierOp = std::function<void(ExtComm)>;
 
-#define NVTE_MAX_REGIONS 16
+#define NVTE_MAX_REGIONS 32
 #define NVTE_MAX_SMS 32
 #define NVTE_MAX_OPS 32
 #define NVTE_MAX_PEERS 8192
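Two small userbuffers changes back this up: NVTE_MAX_REGIONS doubles from 16 to 32 (presumably to leave room for registering buffer regions for more than one quantization mode, though the commit message does not state this explicitly), and the registration check becomes >= so that free_region == NVTE_MAX_REGIONS is rejected, since valid region handles run from 0 to NVTE_MAX_REGIONS - 1. A small Python sketch of the corrected boundary logic:

NVTE_MAX_REGIONS = 32  # mirrors the new value in userbuffers.h

def can_register_region(free_region: int) -> bool:
    # Valid region handles are 0 .. NVTE_MAX_REGIONS - 1, so registration must be
    # refused once free_region reaches NVTE_MAX_REGIONS (hence ">=" rather than ">"
    # in register_user_buffer_collective).
    return free_region < NVTE_MAX_REGIONS

assert can_register_region(31)
assert not can_register_region(32)  # the old ">" comparison would have wrongly allowed this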

transformer_engine/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def torch_version() -> tuple[int, ...]:
 from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
 from transformer_engine.pytorch.module import initialize_ub
 from transformer_engine.pytorch.module import destroy_ub
+from transformer_engine.pytorch.module import UserBufferQuantizationMode
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
 from transformer_engine.pytorch.attention import InferenceParams

transformer_engine/pytorch/module/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,4 +11,4 @@
 from .rmsnorm import RMSNorm
 from .fp8_padding import Fp8Padding
 from .fp8_unpadding import Fp8Unpadding
-from .base import initialize_ub, destroy_ub
+from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
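With the re-export above, the enum is reachable both from the top-level package and from the module subpackage. A quick sanity-check sketch (assumes a working transformer_engine installation):

from transformer_engine.pytorch import UserBufferQuantizationMode
from transformer_engine.pytorch.module import UserBufferQuantizationMode as _Mode

# Both names resolve to the same class; the top-level import is a re-export.
assert UserBufferQuantizationMode is _Mode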

0 commit comments