Commit c65e2e9

[PyTorch Debug] Add max_blockwise_dynamic_range stats (NVIDIA#2137)
* code drop (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fixes (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix (Signed-off-by: Pawel Gadzinski <[email protected]>)

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4b84bc7 commit c65e2e9


4 files changed, +388 -8 lines


tests/pytorch/debug/test_log.py

Lines changed: 195 additions & 2 deletions
@@ -18,7 +18,11 @@
 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-
+from transformer_engine.debug.features.utils.stats_computation import (
+    compute_max_blockwise_dynamic_range,
+    BlockwiseDynamicRangeStat,
+)
+import math

 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
 mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@@ -154,7 +158,7 @@ def test_sanity(feature_dirs):


 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-def test_numerics(fp8_recipe, feature_dirs):
+def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs):
         assert overflows == pytest.approx(expected.cpu(), abs=1e-4)


+LOG_HIGH_PRECISION_CONFIG = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      stats:
+        - dynamic_range
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 1
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 2
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
+def test_log_stats_numerics(feature_dirs, tensor_name):
+    """Check correctness of dynamic range and max blockwise dynamic range stats.
+
+    Tests different tensor types:
+    - activation/weight: use both orientations (rowwise + columnwise) and take the max
+    - gradient/dgrad: use a single orientation (rowwise only)
+    """
+    log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
+
+    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
+        # The 1024 x 1024 tensor holds very small epsilon values in almost all elements,
+        # plus one row of large value A and three rows of large value B.
+        epsilon = 1e-10
+        A = 1000
+        B = 50
+        tensor = torch.zeros(1024, 1024).cuda() + epsilon
+        tensor[0, :] = A
+        tensor[1:4, :] = B
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="layer_name",
+            tensor_name=tensor_name,
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=None,
+            rowwise_quantized_tensor=None,
+            columnwise_quantized_tensor=None,
+        )
+        debug_api.step()
+
+        output = read_log(log_dir)
+
+        max_over_orientations = tensor_name in ["activation", "weight"]
+        max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""
+
+        # Track which stats were found to ensure all are present
+        found_dims_1 = False
+        found_dims_2 = False
+        found_dynamic_range = False
+
+        for line in output.splitlines():
+            if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
+                max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
+                if max_over_orientations:
+                    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
+                    expected = math.log2(A) - math.log2(B)
+                else:
+                    # Rowwise blocks have uniform values -> dynamic_range = 0
+                    expected = 0
+                assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
+                    expected, abs=1e-4
+                )
+                found_dims_1 = True
+            elif (
+                f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
+            ):
+                max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
+                # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
+                expected = math.log2(A) - math.log2(B)
+                assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
+                    expected, abs=1e-4
+                )
+                found_dims_2 = True
+            elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
+                dynamic_range = float(line.split("value=")[1])
+                expected = math.log2(A) - math.log2(epsilon)
+                assert dynamic_range == pytest.approx(expected, abs=1e-4)
+                found_dynamic_range = True
+
+        # Ensure all expected stats were found in the output
+        assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
+        assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
+        assert found_dynamic_range, "dynamic_range not found in output"
+
+
 @pytest.mark.parametrize("layer", ["linear", "transformer"])
 def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     if not fp8_available:
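For a sanity check of the values asserted in this test, they can be reproduced with a small standalone sketch. This is not the `compute_max_blockwise_dynamic_range` implementation from `stats_computation.py` (that file is not shown on this page); it only assumes the blocking scheme implied by the test: non-overlapping groups of `block_size` consecutive elements along the last dimension, with the columnwise orientation obtained by transposing.

```python
# Minimal standalone sketch (CPU), assuming non-overlapping 1D blocks along the last dimension.
# Not the actual compute_max_blockwise_dynamic_range from stats_computation.py.
import math

import torch


def blockwise_dynamic_range_1d(t: torch.Tensor, block_size: int) -> float:
    """Max over blocks of log2(block amax) - log2(block amin); assumes no zeros in t."""
    blocks = t.abs().reshape(-1, block_size)
    amax = blocks.max(dim=1).values
    amin = blocks.min(dim=1).values
    return (torch.log2(amax) - torch.log2(amin)).max().item()


epsilon, A, B = 1e-10, 1000.0, 50.0
t = torch.full((1024, 1024), epsilon)
t[0, :] = A
t[1:4, :] = B

print(blockwise_dynamic_range_1d(t, 4))    # rowwise blocks are uniform -> 0.0
print(blockwise_dynamic_range_1d(t.T, 4))  # columnwise blocks [A, B, B, B] -> log2(A/B) ~= 4.32
print(math.log2(A) - math.log2(epsilon))   # whole-tensor dynamic_range ~= 43.19
```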
@@ -278,3 +383,91 @@ def test_log_grouped_gemm(feature_dirs):
     assert "gemm_0" in output, "gemm0 not found in output"
     assert "gemm_1" in output, "gemm1 not found in output"
     assert "gemm_2" in output, "gemm2 not found in output"
+
+
+def test_compute_max_blockwise_dynamic_range_direct():
+    """Direct unit test for the compute_max_blockwise_dynamic_range function.
+
+    Tests the function with various configurations to ensure correct behavior
+    for different block sizes, dimensions, and orientation settings.
+    """
+    # Create a test tensor with uniform rows but mixed columns
+    # Row 0: all 1000, rows 1-3: all 50, remaining: all 0.01
+    epsilon = 0.01
+    A = 1000.0
+    B = 50.0
+    tensor = torch.zeros(1024, 1024).cuda() + epsilon
+    tensor[0, :] = A
+    tensor[1:4, :] = B
+
+    # Test 1: dims=1, max_over_orientations=False (rowwise only)
+    # Rowwise blocks have uniform values -> dynamic_range should be 0
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    assert result.item() == pytest.approx(
+        0.0, abs=1e-4
+    ), "Rowwise 1D blocks with uniform values should have dynamic_range=0"
+
+    # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise)
+    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(B)
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 3: dims=2, block_size=4 (4x4 tiles)
+    # 2D blocks span multiple rows -> always have mixed values
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(B)
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"2D blocks should capture mixed values from different rows, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 4: Different block size
+    # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon]
+    # So max=A, min=epsilon (not B anymore)
+    stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(epsilon)  # min is epsilon, not B
+    assert result.item() == pytest.approx(
+        expected, abs=1e-4
+    ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}"
+
+    # Test 5: Tensor with all uniform values -> dynamic_range should be 0
+    uniform_tensor = torch.ones(64, 64).cuda() * 42.0
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config)
+    assert result.item() == pytest.approx(
+        0.0, abs=1e-4
+    ), "Uniform tensor should have dynamic_range=0"
+
+    # Test 6: 3D tensor flattening validation using 2D/3D comparison
+    # Create a 4x4 tensor with distinct 2x2 blocks, compute with dims=2, block_size=2
+    # Then reshape to 3D and compute again - results should match if flattening is correct
+    tensor_2d = torch.tensor(
+        [
+            [1.0, 1.0, 10.0, 10.0],
+            [1.0, 1.0, 10.0, 10.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+        ]
+    ).cuda()
+
+    # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100)
+    stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
+    result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config)
+
+    # Reshape to 3D [2, 2, 4] and compute - should give the same result if flattening is correct
+    tensor_3d = tensor_2d.reshape(2, 2, 4)
+    result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config)
+
+    assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), (
+        "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got"
+        f" 2D={result_2d.item()}, 3D={result_3d.item()}"
+    )
+
+    print("All direct tests for compute_max_blockwise_dynamic_range passed!")

transformer_engine/debug/features/log_tensor_stats.py

Lines changed: 72 additions & 5 deletions
@@ -4,7 +4,7 @@

 """LogTensorStats Feature support for nvidia-dlframework-inspect"""

-from typing import Dict, Optional
+from typing import Dict, Optional, List

 import torch

@@ -19,6 +19,10 @@
 from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
+from transformer_engine.debug.features.utils.stats_computation import (
+    add_max_blockwise_dynamic_range_stats,
+    BlockwiseDynamicRangeStat,
+)


 @Registry.register_feature(namespace="transformer_engine")
@@ -44,7 +48,14 @@ class LogTensorStats(BaseLogTensorStats):
             - l1_norm
             - l2_norm
             - cur_amax – maximal absolute value of a tensor,
-            - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
+            - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)`
+            - max_blockwise_dynamic_range – the maximum dynamic range `log2(amax) - log2(nonzero_amin)` across all blocks of size block_size within the tensor.
+              If both the tensor and its transpose are needed in training, this stat is computed for both orientations and the maximum is returned.
+              For `dims=1` a block is block_size consecutive elements; for `dims=2` a block is a block_size x block_size tile.
+
+                - block_size: int, default = 32
+                - dims: int, default = 1, allowed values are 1 and 2
+
     tensors/tensors_struct: List[str]
             list of tensors to log
@@ -88,6 +99,60 @@ class LogTensorStats(BaseLogTensorStats):
                 stats: [dynamic_range]
     """

+    def _is_supported_stat(self, stat: str | Dict):
+        """Returns True if the stat is supported by this feature, False otherwise."""
+        if isinstance(stat, dict):
+            stat_name = list(stat.keys())[0]
+            if stat_name == "max_blockwise_dynamic_range":
+                stat_dict = stat[stat_name]
+                if not isinstance(stat_dict, dict):
+                    return False
+                # Ensure only supported keys are present
+                allowed_keys = {"block_size", "dims"}
+                if any(k not in allowed_keys for k in stat_dict.keys()):
+                    return False
+                block_size = stat_dict.get("block_size", 32)
+                dims = stat_dict.get("dims", 1)
+                # Type and value validation
+                if not isinstance(block_size, int) or not isinstance(dims, int):
+                    return False
+                if block_size > 0 and dims in [1, 2]:
+                    return True
+            return False
+        return stat in BaseLogTensorStats._get_supported_stats_list(None) | {
+            "cur_amax",
+            "dynamic_range",
+        }
+
+    def _parse_max_blockwise_dynamic_range_stats(
+        self, stats: List[str | Dict], tensor_name: str
+    ) -> List[str | BlockwiseDynamicRangeStat]:
+        """
+        Adds all max_blockwise_dynamic_range stats to the stat computation logic.
+        Changes the type of these stats from Dict to the BlockwiseDynamicRangeStat named tuple;
+        other stats are left unchanged.
+
+        For example, if stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}],
+        it is changed to [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=True)]
+        or [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=False)] depending on tensor_name.
+        """
+        max_over_orientations = tensor_name in ["activation", "weight"]
+        parsed_stats = []
+        for stat in stats:
+            if isinstance(stat, dict):
+                block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32)
+                dims = stat["max_blockwise_dynamic_range"].get("dims", 1)
+
+                # Register stat and return the named tuple
+                parsed_stat = add_max_blockwise_dynamic_range_stats(
+                    block_size, dims, max_over_orientations
+                )
+                parsed_stats.append(parsed_stat)
+            else:
+                parsed_stats.append(stat)
+        return parsed_stats
+
     def _get_supported_stats_list(self):
         """Returns stats this feature can log."""
         return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
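To make the parsing step concrete, here is a hypothetical before/after for a stats list. It is inferred from the diff above rather than copied from the repository:

```python
# Hypothetical illustration of what _parse_max_blockwise_dynamic_range_stats produces.
stats = ["cur_amax", {"max_blockwise_dynamic_range": {"block_size": 4, "dims": 2}}]

# For tensor_name="activation" (both orientations are used in training), the dict entry is
# registered via add_max_blockwise_dynamic_range_stats and replaced by a named tuple, roughly:
#   ["cur_amax", BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=True)]
# For tensor_name="gradient" only the rowwise orientation is used, so the flag would be False.
```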
@@ -147,14 +212,16 @@ def inspect_tensor(
         )

         for stat in config["stats"]:
-            assert (
-                stat in self._get_supported_stats_list()
+            assert self._is_supported_stat(
+                stat
             ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."

+        stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"], tensor_name)
+
         STATS_BUFFERS.try_add_buffer(
             layer_name=layer_name,
             tensor_name=tensor_name,
-            stats=config["stats"],
+            stats=stats,
             options=options,
             reduction_group=reduction_group,
             reduce_within_microbatch=reduce_within_microbatch,

transformer_engine/debug/features/utils/stats_buffer.py

Lines changed: 5 additions & 1 deletion
@@ -130,8 +130,12 @@ def log(self):
         for stat_name in self.stats_to_log:
             combiner = STATS[stat_name][1]
             stat_value = combiner(gathered_helper_stats)
+
+            # Convert stat key to string for logging (uses __str__ for named tuples)
+            stat_name_str = str(stat_name)
+
             MetricLogger.log_scalar(
-                f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
+                f"{self.layer_name}_{self.tensor_name}_{stat_name_str}", stat_value, self.iteration
             )
             output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = (
                 stat_value  # for debugging purposes
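The `str(stat_name)` conversion above matters because the new stat keys are named tuples rather than plain strings. Below is a minimal sketch of how such a named tuple could stringify into the log names asserted in `test_log_stats_numerics`; the real `BlockwiseDynamicRangeStat` lives in `stats_computation.py`, which is not shown in this diff, and may differ.

```python
from typing import NamedTuple


class BlockwiseDynamicRangeStat(NamedTuple):
    # Hypothetical sketch only; the actual named tuple is defined in
    # transformer_engine/debug/features/utils/stats_computation.py.
    block_size: int
    dims: int
    max_over_orientations: bool

    def __str__(self) -> str:
        suffix = "_max_over_orientations" if self.max_over_orientations else ""
        return f"max_blockwise_dynamic_range_block_size_{self.block_size}_dims_{self.dims}{suffix}"


print(str(BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)))
# -> max_blockwise_dynamic_range_block_size_4_dims_1_max_over_orientations
```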
