We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fb6bf83 · commit 063bd31 · Copy full SHA for 063bd31
vllm/model_executor/layers/quantization/fp8.py
@@ -988,7 +988,11 @@ def forward_hpu(
988
if self.quant_config.activation_scheme == "dynamic" and not self.block_quant:
989
x_fp8, x_scale = dynamic_quant(x)
990
991
- htorch.core.mark_step()
+ if torch._dynamo.is_compiling():
992
+ torch._dynamo.graph_break()
993
+ else:
994
+ htorch.core.mark_step()
995
+
996
if (self.padded_weights_buffer is None
997
or self.padded_weights_buffer.dtype != x.dtype
998
or self.padded_weights_buffer.device != x.device
0 commit comments