Skip to content

Commit fecc23c

Browse files
committed
fix
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent d474723 commit fecc23c

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

tests/pytorch/debug/test_log.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ def test_numerics(fp8_recipe, feature_dirs):
164164
num_quantizers=3,
165165
)
166166

167-
tensor = torch.zeros(1024, 1024).cuda()
168-
tensor[0, :] = 1000
167+
tensor = torch.randn(1024, 1024).cuda()
169168
quantizer = recipe_state.make_quantizers()[0]
170169
quantized_tensor = quantizer(tensor)
171170

@@ -189,14 +188,14 @@ def test_numerics(fp8_recipe, feature_dirs):
189188
underflows = float(line.split("value=")[1])
190189
expected = (
191190
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
192-
/ dequantized_tensor.numel()
191+
/ tensor.numel()
193192
* 100
194193
)
195194
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
196195
if "mse" in line:
197196
mse = float(line.split("value=")[1])
198197
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
199-
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
198+
assert mse == pytest.approx(expected.cpu(), abs=1e-4)
200199
if "overflows%" in line:
201200
overflows = float(line.split("value=")[1])
202201
expected = (

transformer_engine/debug/features/utils/stats_computation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,20 +211,20 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False):
211211
stats_to_num[stat_num] = len(stats_to_num)
212212
stats_to_num[stat_pct] = len(stats_to_num)
213213

214-
zero_values = torch.tensor([0, 127], device="cuda")
214+
zero_values = torch.tensor([0, 128], device="cuda")
215215

216216
STATS[stat_num] = (
217217
lambda x, aux_dict:
218-
aux_dict[recipe_name].get_data_tensors(
218+
torch.isin(aux_dict[recipe_name].get_data_tensors(
219219
rowwise_data=not columnwise, columnwise_data=columnwise
220-
).isin(zero_values).sum() - (x == 0).sum(),
220+
), zero_values).sum() - (x == 0).sum(),
221221
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
222222
)
223223
STATS[stat_pct] = (
224224
lambda x, aux_dict: (
225-
aux_dict[recipe_name].get_data_tensors(
225+
torch.isin(aux_dict[recipe_name].get_data_tensors(
226226
rowwise_data=not columnwise, columnwise_data=columnwise
227-
).isin(zero_values).sum() - (x == 0).sum())
227+
), zero_values).sum() - (x == 0).sum())
228228
/ aux_dict[recipe_name].numel()
229229
* 100,
230230
lambda buffers, _sn_num=stat_num: 100

0 commit comments

Comments
 (0)