Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamlining of Scaled Dot-Product Attention #901

Open
wants to merge 18 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit. Hold shift + click to select a range.
fdd89a6
[Streamline] Prefer AbsorbSignBiasIntoMultiThreshold transform
iksnagreb Sep 30, 2023
be33bbc
[Streamline] Refactor MoveScalarMulPastMatMul to handle join-node matmul
iksnagreb Sep 30, 2023
09c1993
Remove misplaced/outdated comment
iksnagreb Sep 30, 2023
9dade0c
[Streamline] Soften initializer tests in Absorb1BitMulIntoMatMul/Conv
iksnagreb Sep 30, 2023
8bae5d7
Address some linting issues
iksnagreb Oct 19, 2023
b22ebe3
[Tests] Add test for MoveScalarMulPastMatMul handling join nodes
iksnagreb Oct 19, 2023
c10fa1d
[Deps] Update qonnx version to include FoldTransposeIntoQuantInit fix
iksnagreb Oct 27, 2023
475a27b
[Streamline] Fix FoldQuantWeights input order and shape annotations
iksnagreb Nov 13, 2023
bd6a8f8
[Streamline] Fix AbsorbAddIntoMultiThreshold assumed input order
iksnagreb Nov 13, 2023
1f7dd4c
[Streamline] Add support for Slice to MoveScalarLinearPastInvariants
iksnagreb Nov 15, 2023
b3e50d7
[Streamline] Absorb1BitMulIntoMatMul/Conv does not handle fork-nodes
iksnagreb Nov 17, 2023
0413368
[Deps] Temporarily switch qonnx to my fork including necessary fixes
iksnagreb Nov 17, 2023
2bf7949
Make quantized activation handlers data layout aware
iksnagreb Nov 20, 2023
8783fd4
[Deps] Update qonnx
iksnagreb Nov 20, 2023
2bf37f1
[Deps] Update qonnx
iksnagreb Dec 13, 2023
a4fc498
[Deps] Update qonnx
iksnagreb Mar 13, 2024
6c56382
Fix some typos
iksnagreb Apr 4, 2024
15a9daa
Merge remote-tracking branch 'xilinx/dev' into feature/attention-stre…
iksnagreb Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fetch-repos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

QONNX_COMMIT="2281a777d84aa5cbd7469085c2e534fb4a03ccf9"
QONNX_COMMIT="1a4957ebf2aaf139217fd56109386d4518dd6127"
FINN_EXP_COMMIT="0724be21111a21f0d81a072fccc1c446e053f851"
BREVITAS_COMMIT="d4834bd2a0fad3c1fbc0ff7e1346d5dcb3797ea4"
PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1"
Expand All @@ -40,7 +40,7 @@ RFSOC4x2_BDF_COMMIT="13fb6f6c02c7dfd7e4b336b18b959ad5115db696"
KV260_BDF_COMMIT="98e0d3efc901f0b974006bc4370c2a7ad8856c79"
EXP_BOARD_FILES_MD5="226ca927a16ea4ce579f1332675e9e9a"

QONNX_URL="https://github.com/fastmachinelearning/qonnx.git"
QONNX_URL="https://github.com/iksnagreb/qonnx.git"
FINN_EXP_URL="https://github.com/Xilinx/finn-experimental.git"
BREVITAS_URL="https://github.com/Xilinx/brevitas.git"
PYVERILATOR_URL="https://github.com/maltanar/pyverilator.git"
Expand Down
35 changes: 29 additions & 6 deletions src/finn/transformation/qonnx/fold_quant_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def apply(self, model):
mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
mul_shape, # Note: This shape is known exactly as
# it is an initializer with known shape
)
graph.value_info.append(mul_tensor)
model.set_initializer(mul_tensor.name, scale)
Expand All @@ -168,7 +169,9 @@ def apply(self, model):
act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(act_mul_tensor)
successor.output[0] = act_mul_tensor.name
Expand All @@ -186,19 +189,37 @@ def apply(self, model):
div_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
mul_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(div_tensor)
model.set_initializer(div_tensor.name, scale)

succ_input_name = successor.input[0]
# Detect which input of the add-like successor is
# fed by the quantizer node to select the other
# branch to insert the scale factor
if successor.input[0] == node_out:
succ_input_name = successor.input[1]
else:
succ_input_name = successor.input[0]

act_mul_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
output_shape,
None, # Note: Explicitly delete the shape
# annotation to be redone by the next shape
# inference
)
graph.value_info.append(act_mul_tensor)
successor.input[0] = act_mul_tensor.name

# Detect which input of the add-like successor is
# fed by the quantizer node to select the other
# branch to insert the scale factor
if successor.input[0] == node_out:
successor.input[1] = act_mul_tensor.name
else:
successor.input[0] = act_mul_tensor.name

div_node = helper.make_node(
"Div",
Expand All @@ -210,6 +231,8 @@ def apply(self, model):
# remove old node
graph.node.remove(n)
graph_modified = True
# Note: Running shape inference is necessary as shape
# annotations have been deleted above
model = model.transform(InferShapes())
return (model, graph_modified)
return (model, graph_modified)
72 changes: 61 additions & 11 deletions src/finn/transformation/qonnx/qonnx_activation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from abc import ABC, abstractmethod
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
Expand Down Expand Up @@ -70,7 +70,7 @@ def _check_compatibility(self):
@abstractmethod
def _calculate_act_bias(self):
"""Calculate the activation bias,
which is introduced as an Add node behind the MultiTrheshold node.
which is introduced as an Add node behind the MultiThreshold node.
"""
raise NotImplementedError()

Expand All @@ -82,7 +82,7 @@ def _calculate_thresholds(self):
@abstractmethod
def _calculate_act_scale(self):
"""Calculate the activation scale,
which is indroduced as a Mul node behind the Add node
which is introduced as a Mul node behind the Add node
for the activation bias.
"""
raise NotImplementedError()
Expand Down Expand Up @@ -157,7 +157,7 @@ def replace_quant_node(self):
# Set scale and bias
# If these values are scalar then they can be set as attributes
# of the MultiThreshold node, if not they get inserted as adder and mul nodes
# behind the MultiTrheshold nodes.
# behind the MultiThreshold nodes.
bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
Expand Down Expand Up @@ -355,7 +355,7 @@ def _calculate_thresholds(self):
act_node = self._model.find_direct_predecessors(self._q_node)
act_node = act_node[0]
if act_node.op_type == "Relu":
# Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
# a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
# onnx/finn/handler/act.py#L21
num_distinct_values = 2**bit_width
Expand Down Expand Up @@ -395,8 +395,34 @@ def _calculate_thresholds(self):
else:
thresholds[c][t] = step / selu_scale

# First try to consider the tensor layout of the output for determining
# the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
# If there is a layout annotation, use this to determine the index of
# the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension, fall
# back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]

assert (
thresholds.shape[0] == 1 or thresholds.shape[
0] == num_output_channels
), """Quant node cannot be converted to MultiThreshold because only
per tensor or per channel quantization supported."""

final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)
Expand All @@ -417,12 +443,12 @@ def _remove_activation_node(self, multi_threshold_node):
act_node = self._model.find_direct_predecessors(self._q_node)
if act_node is None:
raise RuntimeError(
"For handling of Relu activations a predecesor to " "the Quant node must exist."
"For handling of Relu activations a predecessor to " "the Quant node must exist."
)
act_node = act_node[0]
if act_node.op_type not in self.valid_predecessor_op_types():
raise RuntimeError(
"The predecesor of the Quant node must be Relu or Selu for handling "
"The predecessor of the Quant node must be Relu or Selu for handling "
"of activations."
)

Expand Down Expand Up @@ -509,7 +535,7 @@ def _calculate_thresholds(self):
else:
raise RuntimeError("Got an unexpected quantizer node type")

# Calculate thersholds, see: https://github.com/Xilinx/brevitas/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/
# blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
# export/onnx/finn/handler/act.py#L76
if bit_width == 1.0:
Expand Down Expand Up @@ -537,13 +563,37 @@ def _calculate_thresholds(self):
for t in range(num_thresholds):
thresholds[c][t] = min_threshold[c] + step[c] * t

# currently only per tensor or per channel quantization is supported
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
# First try to consider the tensor layout of the output for
# determining the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
# If there is a layout annotation, use this to determine the index
# of the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension,
# fall back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]

assert (
thresholds.shape[0] == 1 or thresholds.shape[0] == num_output_channels
), """Quant node cannot be converted to MultiThreshold because only
per tensor or per channel quantization supported."""

final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)

return thresholds

def _calculate_act_scale(self):
Expand Down
2 changes: 1 addition & 1 deletion src/finn/transformation/streamline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def apply(self, model):
BatchNormToAffine(),
ConvertSignToThres(),
MoveMulPastMaxPool(),
MoveScalarLinearPastInvariants(),
AbsorbSignBiasIntoMultiThreshold(),
MoveScalarLinearPastInvariants(),
MoveAddPastMul(),
MoveScalarAddPastMatMul(),
MoveAddPastConv(),
Expand Down
Loading
Loading