Skip to content

Commit

Permalink
Feat (example/common): Added groupwise, float scaled OCP option (#1190)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser authored Feb 14, 2025
1 parent 874ce50 commit 3d50dff
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPDynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFloatMSE
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint
Expand Down Expand Up @@ -133,7 +135,9 @@
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloat}},
'sym': Fp8e4m3OCPWeightPerChannelFloat},
'per_group': {
'sym': Fp8e4m3OCPWeightSymmetricGroupQuant}},
'mse': {
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
Expand Down Expand Up @@ -213,7 +217,9 @@
'float_scale': {
'stats': {
'per_row': {
'sym': FP8e4m3OCPDynamicActPerRowFloat}}},
'sym': FP8e4m3OCPDynamicActPerRowFloat},
'per_group': {
'sym': Fp8e4m3OCPDynamicActPerGroupFloat}}},
'po2_scale': {
'stats': {
'per_row': {
Expand Down
19 changes: 19 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ class FP8e4m3OCPDynamicActPerRowFloat(FP8e4m3OCPDynamicActPerRowFixedPoint):
restrict_scaling_type = RestrictValueType.FP


class Fp8e4m3OCPDynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3OCPActPerTensorFloat):
    """
    Symmetric e4m3 OCP activation quantizer with one scale per group.

    Extends the per-tensor OCP activation quantizer by swapping in the
    groupwise proxy, the group-stats scaling implementation and GROUP
    scaling granularity.
    """
    # Scaling granularity and the statistic used to derive the scale.
    scaling_per_output_type = ScalingPerOutputType.GROUP
    scaling_stats_op = 'min_max'
    # Groupwise scaling implementation and the matching activation proxy.
    scaling_impl = RuntimeDynamicGroupStatsScaling
    proxy_class = GroupwiseActFloatQuantProxyFromInjector


class Fp8e4m3OCPWeightSymmetricGroupQuant(Fp8e4m3OCPWeightPerChannelFloat):
    """
    Signed, symmetric e4m3 OCP weight quantizer with float scales applied per
    block / group / vector.

    The per-channel quantizer is used as the base class purely to re-use some
    of its underlying machinery; only the scaling granularity and the proxy
    are overridden here.
    """
    # Switch the inherited per-channel behaviour to GROUP granularity.
    scaling_per_output_type = ScalingPerOutputType.GROUP
    # Groupwise weight proxy paired with the GROUP granularity above.
    proxy_class = GroupwiseWeightFloatQuantProxyFromInjector


class Fp8e4m3OCPWeightPerChannelFixedPointMSE(MSESymmetricScale,
PerChannelPoTScaling8bit,
Fp8e4m3OCPWeightPerChannelFloat):
Expand Down

0 comments on commit 3d50dff

Please sign in to comment.