 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-
+from transformer_engine.debug.features.utils.stats_computation import (
+    compute_max_blockwise_dynamic_range,
+    BlockwiseDynamicRangeStat,
+)
+import math
 
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
 mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@@ -154,7 +158,7 @@ def test_sanity(feature_dirs):
 
 
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-def test_numerics(fp8_recipe, feature_dirs):
+def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -210,6 +214,112 @@ def test_numerics(fp8_recipe, feature_dirs):
     assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
 
 
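+# Config used by the test below: logs the plain dynamic_range stat together with
+# max_blockwise_dynamic_range over 1D blocks of 4 elements and over 4x4 2D tiles.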
+LOG_HIGH_PRECISION_CONFIG = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      stats:
+        - dynamic_range
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 1
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 2
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
+def test_log_stats_numerics(feature_dirs, tensor_name):
+    """Check correctness of the dynamic range and max blockwise dynamic range stats.
+
+    Tests different tensor types:
+    - activation/weight: uses both orientations (rowwise + columnwise) and takes the max
+    - gradient/dgrad: uses a single orientation (rowwise only)
+    """
+    log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
+
+    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
+        # A 1024 x 1024 tensor of very small epsilon values everywhere except for
+        # one row of large value A and three rows of large value B.
+        epsilon = 1e-10
+        A = 1000
+        B = 50
+        tensor = torch.zeros(1024, 1024).cuda() + epsilon
+        tensor[0, :] = A
+        tensor[1:4, :] = B
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="layer_name",
+            tensor_name=tensor_name,
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=None,
+            rowwise_quantized_tensor=None,
+            columnwise_quantized_tensor=None,
+        )
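+        # Stepping the debug API should flush the stats collected above to the log.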
+        debug_api.step()
+
+        output = read_log(log_dir)
+
+        max_over_orientations = tensor_name in ["activation", "weight"]
+        max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""
+
+        # Track which stats were found to ensure all are present
+        found_dims_1 = False
+        found_dims_2 = False
+        found_dynamic_range = False
+
+        for line in output.splitlines():
+            if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
+                max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
+                if max_over_orientations:
+                    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
+                    expected = math.log2(A) - math.log2(B)
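+                    # = log2(1000) - log2(50) = log2(20) ≈ 4.32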
+                else:
+                    # Rowwise blocks have uniform values -> dynamic_range = 0
+                    expected = 0
+                assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
+                    expected, abs=1e-4
+                )
+                found_dims_1 = True
+            elif (
+                f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
+            ):
+                max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
+                # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
+                expected = math.log2(A) - math.log2(B)
+                assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
+                    expected, abs=1e-4
+                )
+                found_dims_2 = True
+            elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
+                dynamic_range = float(line.split("value=")[1])
+                expected = math.log2(A) - math.log2(epsilon)
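+                # = log2(1000) - log2(1e-10) = log2(1e13) ≈ 43.19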
+                assert dynamic_range == pytest.approx(expected, abs=1e-4)
+                found_dynamic_range = True
+
+        # Ensure all expected stats were found in the output
+        assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
+        assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
+        assert found_dynamic_range, "dynamic_range not found in output"
+
+
 @pytest.mark.parametrize("layer", ["linear", "transformer"])
 def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     if not fp8_available:
@@ -278,3 +388,95 @@ def test_log_grouped_gemm(feature_dirs):
     assert "gemm_0" in output, "gemm0 not found in output"
     assert "gemm_1" in output, "gemm1 not found in output"
     assert "gemm_2" in output, "gemm2 not found in output"
+
+
+def test_compute_max_blockwise_dynamic_range_direct():
+    """Direct unit test for the compute_max_blockwise_dynamic_range function.
+
+    Tests the function with various configurations to ensure correct behavior
+    for different block sizes, dimensions, and orientation settings.
+    """
+    # Create a test tensor with uniform rows but mixed columns:
+    # row 0: all 1000; rows 1-3: all 50; remaining rows: all 0.01.
+    epsilon = 0.01
+    A = 1000.0
+    B = 50.0
+    tensor = torch.zeros(1024, 1024).cuda() + epsilon
+    tensor[0, :] = A
+    tensor[1:4, :] = B
| 401 | + |
| 402 | + # Test 1: dims=1, max_over_orientations=False (rowwise only) |
| 403 | + # Rowwise blocks have uniform values -> dynamic_range should be 0 |
| 404 | + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False) |
| 405 | + result = compute_max_blockwise_dynamic_range(tensor, stat_config) |
| 406 | + assert result.item() == pytest.approx( |
| 407 | + 0.0, abs=1e-4 |
| 408 | + ), "Rowwise 1D blocks with uniform values should have dynamic_range=0" |
| 409 | + |
| 410 | + # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise) |
| 411 | + # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) |
| 412 | + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) |
| 413 | + result = compute_max_blockwise_dynamic_range(tensor, stat_config) |
| 414 | + expected = math.log2(A) - math.log2(B) |
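+    # = log2(1000/50) = log2(20) ≈ 4.32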
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 3: dims=2, block_size=4 (4x4 tiles)
+    # 2D blocks span multiple rows -> always have mixed values
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(B)
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"2D blocks should capture mixed values from different rows, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 4: different block size.
+    # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon],
+    # so max=A and min=epsilon (not B anymore).
+    stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(epsilon)  # min is epsilon, not B
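+    # = log2(1000/0.01) = log2(1e5) ≈ 16.61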
+    assert result.item() == pytest.approx(
+        expected, abs=1e-4
+    ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}"
+
+    # Test 5: tensor with all uniform values -> dynamic_range should be 0
+    uniform_tensor = torch.ones(64, 64).cuda() * 42.0
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config)
+    assert result.item() == pytest.approx(
+        0.0, abs=1e-4
+    ), "Uniform tensor should have dynamic_range=0"
+
+    # Test 6: 3D tensor flattening validation using a 2D/3D comparison.
+    # Create a 4x4 tensor with distinct 2x2 blocks and compute with dims=2, block_size=2,
+    # then reshape to 3D and compute again - results should match if flattening is correct.
+    tensor_2d = torch.tensor(
+        [
+            [1.0, 1.0, 10.0, 10.0],
+            [1.0, 1.0, 10.0, 10.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+        ]
+    ).cuda()
+
+    # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100)
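+    # log2(1000/100) = log2(10) ≈ 3.32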
+    stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
+    result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config)
+
+    # Reshape to 3D [2, 2, 4] and compute again - should give the same result if flattening is correct
+    tensor_3d = tensor_2d.reshape(2, 2, 4)
+    result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config)
+
+    assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), (
+        "3D tensor [2,2,4] flattened to [4,4] must give the same result as the original 2D, got"
+        f" 2D={result_2d.item()}, 3D={result_3d.item()}"
+    )
+
+    print("All direct tests for compute_max_blockwise_dynamic_range passed!")