Commit
Merge branch 'release_v2.0' into release_v2.0
cyanguwa authored Jan 30, 2025
2 parents 13273e4 + 5904a80 commit 3bff787
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -449,7 +449,7 @@ def _functional_forward(
         if with_quantized_compute and not w_is_quantized:
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            weight_quantizer.set_usage(rowwise=True)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
@@ -666,7 +666,7 @@ def _functional_backward(
         if with_quantized_compute:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(rowwise=True, columnwise=True)
+            input_quantizer.set_usage(columnwise=True)
             if with_x_all_gather:
                 x, x_async = gather_along_first_dim(
                     x_local,
@@ -705,7 +705,7 @@ def _functional_backward(
         if with_quantized_compute and not w_is_quantized:
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
-            weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            weight_quantizer.set_usage(columnwise=True)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
@@ -833,6 +833,10 @@ def op_forward(
         next_op: Optional[BasicOperation] = None,
     ) -> torch.Tensor:

+        # Check which grads are required
+        input_requires_grad = ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
@@ -841,6 +845,8 @@ def op_forward(
         grad_output_quantizer = None
         grad_input_quantizer = None
         if with_quantized_compute:
+
+            # Get quantizers
             input_quantizer = self.get_quantizer("forward", 0)
             weight_quantizer = self.get_quantizer("forward", 1)
             if next_op is not None and next_op.num_quantizers("forward") > 0:
@@ -849,6 +855,12 @@ def op_forward(
             if prev_op is not None and prev_op.num_quantizers("backward") > 0:
                 grad_input_quantizer = prev_op.get_quantizer("backward", 0)

+            # Configure quantizers
+            # Note: We cache the quantized input for backward pass,
+            # but discard the quantized weights.
+            input_quantizer.set_usage(columnwise=weight_requires_grad)
+            weight_quantizer.set_usage(columnwise=False)
+
         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
@@ -876,8 +888,8 @@ def op_forward(
         ctx.grad_output_quantizer = grad_output_quantizer
         ctx.grad_input_quantizer = grad_input_quantizer
         ctx.dtype = dtype
-        ctx.input_requires_grad = input_.requires_grad
-        ctx.weight_requires_grad = self.weight.requires_grad
+        ctx.input_requires_grad = input_requires_grad
+        ctx.weight_requires_grad = weight_requires_grad
         ctx.has_prev_op = prev_op is not None

         return output
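
For context, the change above amounts to requesting column-wise quantized data only where it will actually be consumed: the quantized input is cached for the backward (weight-gradient) GEMM, so it needs a column-wise copy only when the weight gradient is required, while the quantized weight is discarded after the forward GEMM and never needs one here. The sketch below is a hypothetical stand-in illustrating that decision logic; only set_usage(rowwise=..., columnwise=...) is taken from the diff itself, and the stub assumes that an omitted keyword leaves the corresponding flag unchanged.

# Hypothetical sketch of the usage-flag logic introduced by this commit.
# Only set_usage(rowwise=..., columnwise=...) mirrors the quantizer API
# shown in the diff; DummyQuantizer and configure_forward_quantizers are
# illustrative stand-ins, not Transformer Engine code.
from typing import Optional


class DummyQuantizer:
    """Records which quantized layouts have been requested."""

    def __init__(self) -> None:
        self.rowwise = True
        self.columnwise = True

    def set_usage(
        self,
        rowwise: Optional[bool] = None,
        columnwise: Optional[bool] = None,
    ) -> None:
        # Assumption: an omitted (None) keyword leaves that flag unchanged,
        # which is why the diff can drop `columnwise=False` and keep only
        # `rowwise=True` where the column-wise flag is configured elsewhere.
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise


def configure_forward_quantizers(
    input_quantizer: DummyQuantizer,
    weight_quantizer: DummyQuantizer,
    weight_requires_grad: bool,
) -> None:
    # The quantized input is cached for the backward (wgrad) GEMM, so a
    # column-wise copy is only needed when the weight gradient is needed.
    input_quantizer.set_usage(columnwise=weight_requires_grad)
    # The quantized weight is discarded after the forward GEMM, so no
    # column-wise copy is produced here.
    weight_quantizer.set_usage(columnwise=False)


if __name__ == "__main__":
    iq, wq = DummyQuantizer(), DummyQuantizer()
    configure_forward_quantizers(iq, wq, weight_requires_grad=False)
    print(iq.columnwise, wq.columnwise)  # False False: no wgrad, so no column-wise copies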
