
Commit 30c0120

[PyTorch] Fix small errors (#2396)
* fix
* fix

Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent e122173 commit 30c0120

File tree: 4 files changed, +13 −12 lines changed

tests/pytorch/distributed/run_gemm_with_overlap.py

Lines changed: 2 additions & 11 deletions
```diff
@@ -24,10 +24,8 @@
     MXFP8Quantizer,
 )
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.module.base import (
-    fill_userbuffers_buffer_for_all_gather,
-    get_cublas_workspace_size_bytes,
-)
+from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -417,10 +415,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
         std=opts.std,
     )
 
-    # Allocate cuBLAS workspace
-    workspace_size = 3 * get_cublas_workspace_size_bytes()
-    workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
-
     # Gather global tensors and calculate reference result (need these first for Fp8 scales)
     if opts.bulk_overlap:
         ker_g = torch.transpose(kernel_t, 0, 1)
@@ -617,7 +611,6 @@ def _fp8_gemm():
         return tex.general_gemm(
             kernel_t_fp8,
             gemm_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
@@ -635,7 +628,6 @@ def _fp8_gemm2(gemm1_out):
         return tex.general_gemm(
             kernel2_t_fp8,
             gemm2_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out2_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
@@ -648,7 +640,6 @@ def _gemm():
         return tex.general_gemm(
             kernel_t,
             gemm_inp,
-            workspace,
             out_dtype=torch.bfloat16,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
             ub=ub_obj,
```
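The test no longer allocates a cuBLAS workspace or passes one to `tex.general_gemm`, which suggests the extension now sizes and manages its workspace internally (note that `get_cublas_workspace_size_bytes` also moved to `transformer_engine.pytorch.cpp_extensions.gemm`). A minimal sketch of the updated call pattern, using made-up bfloat16 operands rather than the test's tensors and assuming a CUDA build of Transformer Engine:

```python
# Sketch only: illustrative operands, not the test's actual tensors.
import torch
import transformer_engine.pytorch.cpp_extensions as tex

kernel_t = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
gemm_inp = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)

# No workspace argument: general_gemm is assumed to allocate its own
# cuBLAS workspace after this change.
result = tex.general_gemm(
    kernel_t,
    gemm_inp,
    out_dtype=torch.bfloat16,
)
```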

transformer_engine/pytorch/cpu_offload.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -471,6 +471,8 @@ def fwd_step(self) -> int:
         """
         if self.num_of_fwds in [None, self.num_layers - 1]:
             # reset the offload synchronizer
+            for layer_id in self.layer_states:
+                self.layer_states[layer_id].release_all_memory()
             self.num_of_fwds = 0
         else:
             self.num_of_fwds += 1
```
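For context, the intent of the added loop can be sketched in isolation: when the forward-step counter resets at the start of a new iteration, every buffer offloaded during the previous iteration is released rather than lingering until teardown. `LayerState` and its internals below are hypothetical stand-ins for the real per-layer offload state; only the reset branch mirrors the diff:

```python
class LayerState:
    """Hypothetical stand-in for the real per-layer offload state."""

    def __init__(self):
        self.buffers = []

    def release_all_memory(self):
        # Drop references so the allocator can reclaim the offload buffers
        self.buffers.clear()


class OffloadSynchronizer:
    """Sketch of the forward-step counter; the reset branch mirrors the diff."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.num_of_fwds = None  # None until the first forward step
        self.layer_states = {i: LayerState() for i in range(num_layers)}

    def fwd_step(self) -> int:
        if self.num_of_fwds in [None, self.num_layers - 1]:
            # reset the offload synchronizer and free last iteration's buffers
            for layer_id in self.layer_states:
                self.layer_states[layer_id].release_all_memory()
            self.num_of_fwds = 0
        else:
            self.num_of_fwds += 1
        return self.num_of_fwds
```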

transformer_engine/pytorch/distributed.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -948,7 +948,13 @@ def _all_gather_fp8(
         if isinstance(inp, Float8Tensor):
             dtype = inp.dtype
             device = inp.device
+        # Temporarily ensure rowwise usage for output tensor creation
+        # since we're gathering rowwise data, not the transpose
+        init_rowwise_usage = quantizer.rowwise_usage
+        init_columnwise_usage = quantizer.columnwise_usage
+        quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+        quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
         out._data = torch.empty(
```
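The save/toggle/restore idiom around `make_empty` generalizes beyond this call site. A minimal sketch, assuming only what the diff shows — that the quantizer exposes `rowwise_usage`/`columnwise_usage` attributes and a `set_usage()` method; the `try`/`finally` is illustrative hardening, not part of the commit:

```python
def make_rowwise_empty(quantizer, out_shape, dtype, device):
    """Create an all-gather output with rowwise storage, then restore usage flags."""
    init_rowwise = quantizer.rowwise_usage
    init_columnwise = quantizer.columnwise_usage
    # The gather moves rowwise data, so the output must carry rowwise storage
    quantizer.set_usage(rowwise=True, columnwise=init_columnwise)
    try:
        return quantizer.make_empty(out_shape, dtype=dtype, device=device)
    finally:
        # Restore the caller's flags whether or not make_empty succeeded
        quantizer.set_usage(rowwise=init_rowwise, columnwise=init_columnwise)
```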

transformer_engine/pytorch/tensor/mxfp8_tensor.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -134,7 +134,9 @@ def make_empty(
         columnwise_data = None
         columnwise_scale_inv = None
         if self.columnwise_usage:
-            columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
+            columnwise_data = torch.empty(
+                shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
             columnwise_scale_inv = torch.empty(
                 round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
                 round_up_to_nearest_multiple(shape[-1], 128),
```
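The likely failure mode being fixed: under columnwise-only usage the rowwise buffer `data` is `None`, so `torch.empty_like(data, ...)` cannot allocate anything, whereas building the buffer from `shape`, `torch.uint8`, and `device` has no dependency on the rowwise buffer. A standalone sketch with a made-up shape (CPU tensor so it runs anywhere):

```python
import torch

shape = (256, 128)        # made-up tensor shape
device = torch.device("cpu")
pin_memory = False
data = None               # rowwise buffer is absent when only columnwise usage is set

# Old code: torch.empty_like(data, pin_memory=pin_memory) -> TypeError when data is None
columnwise_data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
print(columnwise_data.shape, columnwise_data.dtype)  # torch.Size([256, 128]) torch.uint8
```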
