MatMulInteger #2846

Draft · wants to merge 3 commits into main
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -92,6 +92,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_mul_scalar(lhs, rhs)
}

fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_matmul(lhs, rhs)
}

fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
B::int_div(lhs, rhs)
}
4 changes: 4 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -211,6 +211,10 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I>
CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
}

fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!()
}

fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
}
3 changes: 3 additions & 0 deletions crates/burn-cubecl/src/ops/int_ops.rs
@@ -179,6 +179,9 @@ where
numeric::mul_scalar::<R, I>(lhs, rhs)
}

fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
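// Left as todo!() for now: cubecl's matmul kernel is implemented generically
// over Numeric, but its launch API is currently only exposed for floats
// (see the review discussion in matmul_integer.rs below).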
todo!()
}
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
numeric::div::<R, I>(lhs, rhs)
}
3 changes: 3 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -958,6 +958,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!()
}
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
binary_int_ops!(DivOps, B::int_div);

15 changes: 9 additions & 6 deletions crates/burn-import/src/burn/node/base.rs
@@ -9,12 +9,13 @@ use super::{
dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode,
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode,
pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode,
random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode,
resize::ResizeNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode, sum::SumNode,
tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
matmul_integer::MatMulIntegerNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode,
mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::record::PrecisionSettings;
@@ -107,6 +108,7 @@ pub enum Node<PS: PrecisionSettings> {
LayerNorm(LayerNormNode),
Linear(LinearNode),
Matmul(MatmulNode),
MatmulInteger(MatMulIntegerNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Mean(MeanNode),
@@ -163,6 +165,7 @@ macro_rules! match_all {
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
Node::Matmul(node) => $func(node),
Node::MatmulInteger(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Mean(node) => $func(node),
213 changes: 213 additions & 0 deletions crates/burn-import/src/burn/node/matmul_integer.rs
@@ -0,0 +1,213 @@
use core::cmp::Ordering;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

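/// Codegen node for the ONNX `MatMulInteger` operator: emits
/// `(A - a_zero_point).matmul(B - b_zero_point)`, where both zero points
/// are optional and default to 0.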
#[derive(Debug, Clone)]
pub struct MatMulIntegerNode {
pub lhs: TensorType,
pub rhs: TensorType,
pub output: TensorType,
pub a_zero_point: Option<TensorType>,
pub b_zero_point: Option<TensorType>,
}

impl MatMulIntegerNode {
pub fn new(
lhs: TensorType,
rhs: TensorType,
output: TensorType,
a_zero_point: Option<TensorType>,
b_zero_point: Option<TensorType>,
) -> Self {
// Validate tensor types - using Int for quantized tensors
if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int {
panic!("MatMulInteger is only implemented for integer tensors");
}

// Output is typically an Int32 tensor in ONNX
if output.kind != TensorKind::Int {
panic!("MatMulInteger output must be an integer tensor");
}

// Validate zero points if provided
if let Some(a_zero) = &a_zero_point {
if a_zero.kind != TensorKind::Int {
panic!("A zero point must be an integer tensor");
}
}

if let Some(b_zero) = &b_zero_point {
if b_zero.kind != TensorKind::Int {
panic!("B zero point must be an integer tensor");
}
}

Self {
lhs,
rhs,
output,
a_zero_point,
b_zero_point,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MatMulIntegerNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
let mut input_types = vec![
Type::Tensor(self.lhs.clone()),
Type::Tensor(self.rhs.clone()),
];
if let Some(a_zero_point) = &self.a_zero_point {
input_types.push(Type::Tensor(a_zero_point.clone()));
}
if let Some(b_zero_point) = &self.b_zero_point {
input_types.push(Type::Tensor(b_zero_point.clone()));
}
input_types
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let lhs = scope.tensor_use_owned(&self.lhs, node_position);
let rhs = scope.tensor_use_owned(&self.rhs, node_position);
let output = &self.output.name;

let a_zero_point = if let Some(a_zero_point) = &self.a_zero_point {
scope.tensor_use_owned(a_zero_point, node_position)
} else {
quote! { 0 }
};

let b_zero_point = if let Some(b_zero_point) = &self.b_zero_point {
scope.tensor_use_owned(b_zero_point, node_position)
} else {
quote! { 0 }
};

let lhs_dim = self.lhs.dim;
let rhs_dim = self.rhs.dim;

// Support broadcasting for missing dimensions
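// (the alternating 0 / -1 axes unsqueeze the lower-rank operand at its
// leading and trailing ends until the ranks match)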
match lhs_dim.cmp(&rhs_dim) {
Ordering::Greater => {
let axes = (0..lhs_dim - rhs_dim)
.map(|i| if i % 2 == 0 { 0 } else { -1 })
.collect::<Vec<i64>>();
let axes = axes.to_tokens();

if rhs_dim == 1 {
let squeeze_dim = lhs_dim - 1;
quote! {
let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point)).squeeze(#squeeze_dim);
}
} else {
quote! {
let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point));
}
}
}
Ordering::Less => {
let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();

if lhs_dim == 1 {
let squeeze_dim = rhs_dim - 2;
quote! {
let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point)).squeeze(#squeeze_dim);
}
} else {
quote! {
let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point));
}
}
}
Ordering::Equal => quote! {
let #output = (#lhs - #a_zero_point).matmul((#rhs - #b_zero_point));

@laggui (Member) commented on Feb 26, 2025:
Thanks for contributing! 🙏

Just a quick note: matmul is only available for float tensors right now.

We could probably move this to the numeric operations, but I think we would need to expose the cubecl implementation for numerics (iirc it's only exposed for floats, but actually implemented for numerics).

This is probably gonna be a bit of a blocker for the current PR 😅
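
For context, the float-only entry point mentioned here lives in crates/burn-tensor/src/tensor/api/float.rs and currently looks roughly like this (a trimmed sketch, not the exact source):

    impl<B: Backend, const D: usize> Tensor<B, D> {
        /// Matrix multiplication; available on the Float kind only today.
        pub fn matmul(self, other: Self) -> Self {
            check!(TensorCheck::matmul(&self, &other));
            Self::new(TensorPrimitive::Float(B::float_matmul(
                self.primitive.tensor(),
                other.primitive.tensor(),
            )))
        }
    }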

@laggui (Member):

In case you haven't stumbled upon the contributor guide, there are a couple of steps detailed to add a new operator: https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html#adding-new-operators

@NewBornRustacean (Author):

Thanks for the quick and kind review! That was my concern as well (the matmul thing).
I'll dig more and ping you when it's ready 👍

(Thanks for the heads-up about the failing CI, btw.)

@NewBornRustacean (Author) commented on Feb 28, 2025:

@laggui hi! I have a simple question 🔢

This is the flow as I understand it:

    MatMulIntegerNode → tensor.matmul() → Backend::int_matmul()

To expand matmul to integers, I was trying to define another function in burn-tensor/src/tensor/ops/int_tensor.rs, like below:
    /// Performs matrix multiplication between two integer tensors.
    ///
    /// For 2D tensors, this computes the standard matrix product: if input tensors
    /// are of shapes (n×m) and (m×p), the output will be of shape (n×p).
    ///
    /// For tensors with more than 2 dimensions, this performs batched matrix multiplication.
    /// If the first tensor has shape (b1,b2,...,bn,n,m) and the second tensor has shape
    /// (b1,b2,...,bn,m,p), the output will have shape (b1,b2,...,bn,n,p).
    ///
    /// Broadcasting is supported for non-matching batch dimensions.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side integer tensor.
    /// * `rhs` - The right-hand side integer tensor.
    ///
    /// # Returns
    ///
    /// A new integer tensor containing the result of the matrix multiplication.
    ///
    /// # Panics
    ///
    /// Panics if the tensors are not compatible for matrix multiplication
    /// (i.e., if the number of columns in `lhs` does not equal the number of rows in `rhs`).
    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

but if I add this fn, need to add this to all the backends, I guess. (which means to me, like shot-gun surgery).
did I understand the situation correctly? if that's the right direction, I'll keep going on. if you have a better design, please let me know!

thanks :)

@laggui (Member):

Pretty much! The op should be defined in the numeric ops. So the current matmul (public API) would be moved from float to numeric, and then there would be a float_matmul and int_matmul as defined.

As I mentioned above, matmul for our cubecl backends is only exposed for floats, I believe, but this could be changed since the actual implementation in cubecl supports numeric iirc (so integers as well). Would have to make a change in cubecl for that though.

I'm not 100% sure if the other backends will support int matmul, but this can be added progressively (and some implementations left as todo!() during the review process).

That would be the route to go.

Linking the guide if it can be useful: https://burn.dev/contributor-book/guides/adding-a-new-operation-to-burn.html

This is a bit more work than just adding an ONNX op for import support 😅 hence why it was listed under the "harder" ops to add.
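
A rough sketch of that route, with hypothetical trait plumbing (burn's actual `Numeric` trait and primitive types may differ in the details):

    // Public API: `matmul` moves from the Float-only impl block to the
    // numeric one, so both Float and Int tensors get it.
    impl<B: Backend, const D: usize, K: Numeric<B>> Tensor<B, D, K> {
        pub fn matmul(self, other: Self) -> Self {
            check!(TensorCheck::matmul(&self, &other));
            Self::new(K::matmul(self.primitive, other.primitive))
        }
    }

    // Kind-level dispatch: each tensor kind forwards to its own backend op.
    impl<B: Backend> Numeric<B> for Float {
        fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
            B::float_matmul(lhs, rhs)
        }
    }

    impl<B: Backend> Numeric<B> for Int {
        fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
            B::int_matmul(lhs, rhs)
        }
    }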

@NewBornRustacean (Author):

Wow, thanks for the super-fast reply!
Yeah, it's definitely one of the "harder" jobs 😅 but interesting!

For now, I'm trying to add int_matmul and implement it for NdArrayTensor as a simple test
(following the guide you shared, thanks!).
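
A minimal smoke test for that path might look like this (a sketch that assumes the public `matmul` is already wired up for Int tensors as discussed above):

    use burn_ndarray::NdArray;
    use burn_tensor::{Int, Tensor, TensorData};

    #[test]
    fn int_matmul_2x2() {
        let device = Default::default();
        let a = Tensor::<NdArray, 2, Int>::from_data([[1, 2], [3, 4]], &device);
        let b = Tensor::<NdArray, 2, Int>::from_data([[5, 6], [7, 8]], &device);
        // Routes to NdArray's int_matmul through the numeric matmul API.
        let c = a.matmul(b);
        c.into_data()
            .assert_eq(&TensorData::from([[19i64, 22], [43, 50]]), false);
    }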

@laggui (Member):

Sounds good! I think tch should also be easy since their matmul supports integers out of the box iirc.

},
}
}

fn into_node(self) -> Node<PS> {
Node::MatmulInteger(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{matmul_integer::MatMulIntegerNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_matmul_integer() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(MatMulIntegerNode::new(
TensorType::new_int("tensor1", 4),
TensorType::new_int("tensor2", 4),
TensorType::new_int("tensor3", 4),
Some(TensorType::new_int("a_zero_point", 1)),
Some(TensorType::new_int("b_zero_point", 1)),
));

graph.register_input_output(
vec![
"tensor1".to_string(),
"tensor2".to_string(),
"a_zero_point".to_string(),
"b_zero_point".to_string(),
],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4, Int>,
tensor2: Tensor<B, 4, Int>,
a_zero_point: Tensor<B, 1, Int>,
b_zero_point: Tensor<B, 1, Int>,
) -> Tensor<B, 4, Int> {
let tensor3 = (tensor1 - a_zero_point).matmul((tensor2 - b_zero_point));
tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -25,6 +25,7 @@ pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod matmul_integer;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod mean;
28 changes: 27 additions & 1 deletion crates/burn-ndarray/src/ops/int_tensor.rs
@@ -7,7 +7,8 @@ use burn_tensor::Distribution;

use burn_tensor::ElementConversion;
use core::ops::Range;
use ndarray::IntoDimension;
use std::ops::Mul;
use ndarray::{IntoDimension, Ix2};
use ndarray::Zip;

// Current crate
@@ -162,6 +163,31 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps
fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::mul_scalar(lhs, rhs)
}
fn int_matmul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
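// NOTE: this first pass assumes 2-D inputs; the `into_dimensionality::<Ix2>()`
// calls below panic for batched (higher-rank) tensors.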
let lhs_array = &lhs.array;
let rhs_array = &rhs.array;

// Get dimensions
let m = lhs_array.shape()[0];
let k = lhs_array.shape()[1];
let n = rhs_array.shape()[1];

// Convert dynamic arrays to 2D views
let lhs_2d = lhs_array.view().into_dimensionality::<Ix2>().unwrap();
let rhs_2d = rhs_array.view().into_dimensionality::<Ix2>().unwrap();

// Create the result array with fixed dimensions
let mut result = ndarray::Array2::<I>::zeros((m, n));

// Perform matrix multiplication with fixed dimensions
ndarray::linalg::general_mat_mul(I::one(), &lhs_2d, &rhs_2d, I::zero(), &mut result);

// Convert back to dynamic dimensions
let result_dyn = result.into_dyn();

// Create the tensor
NdArrayTensor::new(result_dyn.into_shared())
}

fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::div(lhs, rhs)
3 changes: 3 additions & 0 deletions crates/burn-router/src/ops/op_int.rs
@@ -609,6 +609,9 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
out
}

fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
todo!()
}
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
let client = lhs.client.clone();
let dtype = lhs.dtype;
7 changes: 7 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
@@ -135,6 +135,13 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
)
}

fn int_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
// `Tensor::dot` is restricted to 1-D vectors in torch, so use `matmul`,
// which handles matrices and batched shapes and supports integer dtypes.
let result = lhs.tensor.matmul(&rhs.tensor);
TchTensor::new(result)
}

fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
let copy = false;
let non_blocking = true;
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/api/float.rs
@@ -178,7 +178,7 @@ where
/// # Panics
///
/// If the two tensors don't have a compatible shape.
pub fn matmul(self, other: Self) -> Self {
pub fn float_matmul(self, other: Self) -> Self {
check!(TensorCheck::matmul(&self, &other));
Self::new(TensorPrimitive::Float(B::float_matmul(
self.primitive.tensor(),