
Commit 3b4366b

Fix CI failures for UB overlap changes (#2149)

Signed-off-by: djns99 <[email protected]>
1 parent: 67fcc15

4 files changed, +14 -6 lines


examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 5 additions & 1 deletion

@@ -264,7 +264,11 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
         [batched_size, hidden_size],
         tp_size,
         quantization_modes=[
-            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if opts.fp8
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
         ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
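The fix in all three scripts is the same: the bare UserBufferQuantizationMode name is replaced with its fully qualified path, te.module.base.UserBufferQuantizationMode, since the unqualified name was evidently no longer in scope after the UB overlap changes and broke CI. A minimal sketch of the corrected call pattern follows; the shapes, tp_size, and use_fp8 flag are illustrative assumptions, not values from the commit:

import torch
import transformer_engine.pytorch as te

use_fp8 = True  # assumed stand-in for opts.fp8

# Register userbuffer (UB) communication buffers, referring to the
# quantization-mode enum through its fully qualified module path.
te.module.base.initialize_ub(
    [1024, 4096],  # [batched_size, hidden_size]; illustrative values
    2,             # tp_size; illustrative value
    quantization_modes=[
        te.module.base.UserBufferQuantizationMode.FP8
        if use_fp8
        else te.module.base.UserBufferQuantizationMode.NONE
    ],
    dtype=torch.bfloat16,
)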

tests/pytorch/distributed/run_layer_with_overlap.py

Lines changed: 6 additions & 2 deletions

@@ -420,10 +420,14 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     }

     quantization_modes = [
-        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        (
+            te.module.base.UserBufferQuantizationMode.FP8
+            if opts.fp8
+            else te.module.base.UserBufferQuantizationMode.NONE
+        )
     ]
     if opts.first_last_layers_bf16 and opts.fp8:
-        quantization_modes.append(UserBufferQuantizationMode.NONE)
+        quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)

     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
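A detail specific to this test: when opts.first_last_layers_bf16 is set alongside opts.fp8, the first and last layers run unquantized while the interior layers use FP8, so UB buffers must be registered for both modes. A hedged sketch of that list construction, with plain booleans standing in for the parsed options:

import transformer_engine.pytorch as te

fp8 = True                     # assumed stand-in for opts.fp8
first_last_layers_bf16 = True  # assumed stand-in for opts.first_last_layers_bf16

# FP8 buffers for the quantized interior layers...
quantization_modes = [te.module.base.UserBufferQuantizationMode.FP8]
if first_last_layers_bf16 and fp8:
    # ...plus non-quantized buffers for the bf16 boundary layers.
    quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)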

tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py

Lines changed: 2 additions & 2 deletions

@@ -508,9 +508,9 @@ def main() -> None:
         torch.distributed.get_world_size(group),
         quantization_modes=[
             (
-                UserBufferQuantizationMode.FP8
+                te.module.base.UserBufferQuantizationMode.FP8
                 if model_config.quantization is not None
-                else UserBufferQuantizationMode.NONE
+                else te.module.base.UserBufferQuantizationMode.NONE
             )
         ],
         dtype=model_config.dtype,

transformer_engine/pytorch/module/base.py

Lines changed: 1 addition & 1 deletion

@@ -473,7 +473,7 @@ def add_ub(
     fp8_buf = (name in layers_all_gather_overlap) or (
         user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
     )
-    ub_cfg.update(ub_cfgs[name])
+    ub_cfg.update(user_ub_cfg[name])
     ub_cfg["fp8_buf"] = fp8_buf
     add_ub(name, quantization_mode, **ub_cfg)
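The base.py change fixes a different bug: the merge line referenced ub_cfgs, while the surrounding code (see the fp8_buf computation just above) reads the user-supplied overrides from user_ub_cfg; the stale name presumably no longer resolved after the UB overlap refactor. A hedged sketch of the merge-then-override pattern the fixed line implements, with default_cfg as an assumed stand-in for the defaults built earlier in the function:

ub_cfg = dict(default_cfg)        # per-buffer defaults (assumed name)
ub_cfg.update(user_ub_cfg[name])  # user overrides replace the defaults
ub_cfg["fp8_buf"] = fp8_buf       # computed flag wins over user input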
