[PyTorch] Reduce tensor dimensions in MXFP8 tests #1435

Open · wants to merge 2 commits into base: release_v2.0
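Note: the PR carries no description on this page, so the following summary is inferred from the title and the diff. The MXFP8 test configurations previously bumped tensor sizes up to multiples of 128, but the skip logic in tests/pytorch/test_fusible_ops.py is updated below so that MXFP8 GEMMs only require dims divisible by 32; the test sizes are reduced accordingly. A minimal sketch of that divisibility rule, assuming it mirrors the updated maybe_skip_quantization check:

```python
import math

def mxfp8_dims_ok(dims: tuple[int, ...]) -> bool:
    # Mirrors the updated check below: the flattened leading dims and the last dim
    # must each be divisible by 32 for an MXFP8 GEMM.
    return math.prod(dims[:-1]) % 32 == 0 and dims[-1] % 32 == 0

assert mxfp8_dims_ok((32, 32, 128))  # new SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE in run_numerics.py
assert mxfp8_dims_ok((4, 2, 4, 32))  # one of the new input shapes in test_fusible_ops.py
```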
18 changes: 6 additions & 12 deletions tests/pytorch/distributed/run_numerics.py
@@ -23,8 +23,8 @@
 )
 from run_layer_with_overlap import _compare_tensors
 
-SEQ_LEN, BATCH_SIZE = 16, 16
-HIDDEN_SIZE = 64
+SEQ_LEN, BATCH_SIZE = 32, 32
+HIDDEN_SIZE = 128
 NR_HEADS = 4
 WORLD_RANK, WORLD_SIZE = None, None
 NCCL_WORLD = None
@@ -79,6 +79,8 @@ def main(argv=None, namespace=None):
     parser.add_argument("--quantization", type=str, default=None)
     args = parser.parse_args(argv, namespace)
 
+    QUANTIZATION = args.quantization
+
     test_dict = [
         test_linear,
         test_layernorm,
@@ -87,14 +89,6 @@
         test_transformer_layer,
     ]
 
-    # Quantization scheme
-    QUANTIZATION = args.quantization
-    if QUANTIZATION == "mxfp8":
-        global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
-        SEQ_LEN = 64
-        BATCH_SIZE = 64
-        HIDDEN_SIZE = 256
-
     for test in test_dict:
         test()
     dist.destroy_process_group()
@@ -575,7 +569,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
     """
     # Set parameter data type
     params_dtype = kwargs.get("params_dtype", torch.float32)
-    FFN_HIDDEN_SIZE = {None: 32, "fp8": 64, "mxfp8": 256}[QUANTIZATION]
+    FFN_HIDDEN_SIZE = 128
 
     # Create models
     model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
@@ -665,7 +659,7 @@ def test_layernorm_mlp():
 @run_distributed_test()
 def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
     params_dtype = kwargs.get("params_dtype", torch.float32)
-    FFN_HIDDEN_SIZE = {None: 32, "fp8": 64, "mxfp8": 256}[QUANTIZATION]
+    FFN_HIDDEN_SIZE = 128
 
     model_single_node = te.TransformerLayer(
         HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs
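With the MXFP8 special-casing removed above, every quantization scheme now runs these distributed numerics tests with the same sizes (SEQ_LEN = BATCH_SIZE = 32, HIDDEN_SIZE = FFN_HIDDEN_SIZE = 128). A quick sanity sketch; the divisibility thresholds (16 for FP8, 32 for MXFP8) are taken from the skip logic in tests/pytorch/test_fusible_ops.py further down, not from this file:

```python
SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, FFN_HIDDEN_SIZE = 32, 32, 128, 128

# The flattened token count and the feature dims are multiples of both 16 (FP8)
# and 32 (MXFP8), matching the requirements those skip checks enforce.
for dim in (SEQ_LEN * BATCH_SIZE, HIDDEN_SIZE, FFN_HIDDEN_SIZE):
    assert dim % 16 == 0 and dim % 32 == 0, dim
```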
8 changes: 4 additions & 4 deletions tests/pytorch/distributed/test_fusible_ops.py
@@ -315,8 +315,8 @@ def _test_reduce_scatter(
 
 def _test_basic_linear(
     *,
-    local_weight_shape: tuple[int, int] = (128, 128),
-    local_batch_size: int = 128,
+    local_weight_shape: tuple[int, int] = (16, 16),
+    local_batch_size: int = 16,
     dtype: torch.dtype = torch.float32,
     device: torch.device = "cuda",
     quantization: Optional[str] = None,
@@ -459,8 +459,8 @@ def _test_basic_linear(
 def _test_linear(
     *,
     bias: bool = True,
-    local_weight_shape: tuple[int, int] = (128, 128),
-    local_batch_size: int = 128,
+    local_weight_shape: tuple[int, int] = (16, 16),
+    local_batch_size: int = 16,
     dtype: torch.dtype = torch.float32,
     device: torch.device = "cuda",
     quantization: Optional[str] = None,
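Both helpers above now default to a much smaller per-rank problem: a 16x16 local weight and a local batch of 16. As a rough illustration of what those two parameters mean, assuming the usual (out_features, in_features) weight layout of te.Linear / torch.nn.Linear (not something stated in this diff), the per-rank forward pass boils down to a GEMM like this:

```python
import torch

out_features, in_features = 16, 16  # local_weight_shape
local_batch_size = 16

weight = torch.randn(out_features, in_features)
x = torch.randn(local_batch_size, in_features)
y = x @ weight.t()  # per-rank forward GEMM of the linear op
assert y.shape == (local_batch_size, out_features)
```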
50 changes: 21 additions & 29 deletions tests/pytorch/test_fusible_ops.py
@@ -64,8 +64,8 @@ def maybe_skip_quantization(
             if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                 pytest.skip("FP8 GEMMs require dims that are divisible by 16")
         elif quantization == "mxfp8":
-            if math.prod(dims[:-1]) % 128 != 0 or dims[-1] % 128 != 0:
-                pytest.skip("FP8 GEMMs require dims that are divisible by 128")
+            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
+                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
 
     # Check if device is supported
     if device is not None and torch.device(device).type != "cuda":
@@ -368,6 +368,7 @@ def test_fp8_scale_update(
     def test_dtype_cast(
         self,
         *,
+        size: int = 32,
         init_dtype: torch.dtype,
         final_dtype: torch.dtype,
         device: torch.device = "cuda",
@@ -379,11 +380,6 @@
         maybe_skip_quantization(quantization, device=device)
         with_quantization = quantization is not None
 
-        # Data dimensions
-        size = 16
-        if quantization == "mxfp8":
-            size = 128
-
         # Random data
         dtype = torch.float32
         if torch.float16 in (init_dtype, final_dtype):
@@ -437,6 +433,7 @@ def test_dtype_cast(
     def test_pyt_autocast(
         self,
         *,
+        size: int = 32,
         model_dtype: torch.dtype,
         autocast_dtype: torch.dtype,
         device: torch.device = "cuda",
@@ -450,11 +447,6 @@
         quantized_compute = quantization is not None
         maybe_skip_quantization(quantization)
 
-        # Data dimensions
-        size = 16
-        if quantization == "mxfp8":
-            size = 128
-
         # Construct operation
         recipe = make_recipe(quantization)
         with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
@@ -692,7 +684,7 @@ def test_bias(
     def test_quantize(
         self,
         *,
-        in_shape: Iterable[int] = (128, 128),
+        in_shape: Iterable[int] = (32, 32),
         dtype: torch.dtype = torch.bfloat16,
         device: torch.device = "cuda",
         quantization: str,
@@ -859,8 +851,8 @@ def _test_basic_linear(
         )
         torch.testing.assert_close(dw_test, w_ref.grad, **tols)
 
-    @pytest.mark.parametrize("weight_shape", ((128, 128), (3, 5)))
-    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 4, 8, -1)))
+    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
     @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
@@ -921,8 +913,8 @@ def test_linear(
         self,
         *,
         bias: bool,
-        weight_shape: tuple[int, int] = (128, 128),
-        in_shape: Iterable[int] = (128, -1),
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
         dtype: torch.dtype = torch.float32,
         device: torch.device = "cuda",
         quantization: Optional[str],
@@ -1012,8 +1004,8 @@ def test_linear(
             db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
             torch.testing.assert_close(db_test, b_ref.grad, **tols)
 
-    @pytest.mark.parametrize("weight_shape", ((7, 2), (128,)))
-    @pytest.mark.parametrize("in_shape", ((-1,), (6, 64, -1)))
+    @pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("zero_centered_gamma", (False, True))
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@@ -1182,8 +1174,8 @@ def test_layer_norm_autocast(
         torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype))
         torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype))
 
-    @pytest.mark.parametrize("weight_shape", ((19,), (128,)))
-    @pytest.mark.parametrize("in_shape", ((-1,), (6, 64, -1)))
+    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("zero_centered_gamma", (False, True))
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@@ -1395,7 +1387,7 @@ def test_make_extra_output(
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
     @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
-    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (128, 1, 128)))
+    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_activation(
@@ -1491,7 +1483,7 @@ def test_activation(
     def test_swiglu(
         self,
         *,
-        out_shape: Iterable[int] = (128, 128),
+        out_shape: Iterable[int] = (32, 32),
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
@@ -1560,8 +1552,8 @@ def setup_class(cls) -> None:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
 
-    @pytest.mark.parametrize("weight_shape", ((128, 128), (3, 5)))
-    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (128, -1)))
+    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
     @pytest.mark.parametrize("quantized_weight", (False, True))
@@ -1678,8 +1670,8 @@ def test_forward_linear_bias_add(
         self,
         *,
         bias: bool,
-        weight_shape: tuple[int, int] = (128, 128),
-        in_shape: Iterable[int] = (128, -1),
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
@@ -1791,8 +1783,8 @@ def test_forward_linear_bias_add(
     def test_backward_linear_add(
         self,
         *,
-        weight_shape: tuple[int, int] = (128, 128),
-        in_shape: Iterable[int] = (128, -1),
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
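As a worked check of one of the new parametrizations above: weight_shape (64, 32) with in_shape (4, 2, 4, -1) resolves to an input of shape (4, 2, 4, 32), whose flattened leading dims (4 * 2 * 4 = 32) and last dim (32) are both multiples of 32, so it passes the updated MXFP8 divisibility check; the tiny shapes such as (3, 5) and (2, 13) presumably remain to cover the non-quantized paths. The snippet below only re-applies the divisibility rule shown earlier; whether a case actually runs also depends on checks outside this diff:

```python
import math

in_shape = (4, 2, 4, 32)  # (4, 2, 4, -1) with -1 resolved from weight_shape (64, 32)
assert math.prod(in_shape[:-1]) % 32 == 0  # 4 * 2 * 4 = 32
assert in_shape[-1] % 32 == 0              # 32
```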