MatMulInteger #2846
Draft
NewBornRustacean wants to merge 3 commits into tracel-ai:main from NewBornRustacean:add-matmulinteger
@@ -0,0 +1,213 @@
use core::cmp::Ordering;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone)]
pub struct MatMulIntegerNode {
    pub lhs: TensorType,
    pub rhs: TensorType,
    pub output: TensorType,
    pub a_zero_point: Option<TensorType>,
    pub b_zero_point: Option<TensorType>,
}

impl MatMulIntegerNode {
    pub fn new(
        lhs: TensorType,
        rhs: TensorType,
        output: TensorType,
        a_zero_point: Option<TensorType>,
        b_zero_point: Option<TensorType>,
    ) -> Self {
        // Validate tensor types - using Int for quantized tensors
        if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int {
            panic!("MatMulInteger is only implemented for integer tensors");
        }

        // Output is typically an Int32 tensor in ONNX
        if output.kind != TensorKind::Int {
            panic!("MatMulInteger output must be an integer tensor");
        }

        // Validate zero points if provided
        if let Some(a_zero) = &a_zero_point {
            if a_zero.kind != TensorKind::Int {
                panic!("A zero point must be an integer tensor");
            }
        }

        if let Some(b_zero) = &b_zero_point {
            if b_zero.kind != TensorKind::Int {
                panic!("B zero point must be an integer tensor");
            }
        }

        Self {
            lhs,
            rhs,
            output,
            a_zero_point,
            b_zero_point,
        }
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MatMulIntegerNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        let mut input_types = vec![
            Type::Tensor(self.lhs.clone()),
            Type::Tensor(self.rhs.clone()),
        ];
        if let Some(a_zero_point) = &self.a_zero_point {
            input_types.push(Type::Tensor(a_zero_point.clone()));
        }
        if let Some(b_zero_point) = &self.b_zero_point {
            input_types.push(Type::Tensor(b_zero_point.clone()));
        }
        input_types
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let lhs = scope.tensor_use_owned(&self.lhs, node_position);
        let rhs = scope.tensor_use_owned(&self.rhs, node_position);
        let output = &self.output.name;

        let a_zero_point = if let Some(a_zero_point) = &self.a_zero_point {
            scope.tensor_use_owned(a_zero_point, node_position)
        } else {
            quote! { 0 }
        };

        let b_zero_point = if let Some(b_zero_point) = &self.b_zero_point {
            scope.tensor_use_owned(b_zero_point, node_position)
        } else {
            quote! { 0 }
        };

        let lhs_dim = self.lhs.dim;
        let rhs_dim = self.rhs.dim;

        // Support broadcasting for missing dimensions
        match lhs_dim.cmp(&rhs_dim) {
            Ordering::Greater => {
                let axes = (0..lhs_dim - rhs_dim)
                    .map(|i| if i % 2 == 0 { 0 } else { -1 })
                    .collect::<Vec<i64>>();
                let axes = axes.to_tokens();

                if rhs_dim == 1 {
                    let squeeze_dim = lhs_dim - 1;
                    quote! {
                        let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point)).squeeze(#squeeze_dim);
                    }
                } else {
                    quote! {
                        let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point));
                    }
                }
            }
            Ordering::Less => {
                let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();

                if lhs_dim == 1 {
                    let squeeze_dim = rhs_dim - 2;
                    quote! {
                        let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point)).squeeze(#squeeze_dim);
                    }
                } else {
                    quote! {
                        let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point));
                    }
                }
            }
            Ordering::Equal => quote! {
                let #output = (#lhs - #a_zero_point).matmul((#rhs - #b_zero_point));
            },
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::MatmulInteger(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{matmul_integer::MatMulIntegerNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_matmul_integer() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(MatMulIntegerNode::new(
            TensorType::new_int("tensor1", 4),
            TensorType::new_int("tensor2", 4),
            TensorType::new_int("tensor3", 4),
            Some(TensorType::new_int("a_zero_point", 1)),
            Some(TensorType::new_int("b_zero_point", 1)),
        ));

        graph.register_input_output(
            vec![
                "tensor1".to_string(),
                "tensor2".to_string(),
                "a_zero_point".to_string(),
                "b_zero_point".to_string(),
            ],
            vec!["tensor3".to_string()],
        );

        let expected = quote! {
            use burn::tensor::Int;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 4, Int>,
                    tensor2: Tensor<B, 4, Int>,
                    a_zero_point: Tensor<B, 1, Int>,
                    b_zero_point: Tensor<B, 1, Int>,
                ) -> Tensor<B, 4, Int> {
                    let tensor3 = (tensor1 - a_zero_point).matmul((tensor2 - b_zero_point));
                    tensor3
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
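As a sanity check (not part of the PR), the zero-point arithmetic that the generated forward pass performs — `(lhs - a_zero_point).matmul(rhs - b_zero_point)`, matching the ONNX MatMulInteger definition — can be verified on plain `i32` 2x2 matrices:

```rust
// Minimal sketch of the MatMulInteger arithmetic: subtract each operand's
// zero point, then matrix-multiply with i32 accumulation.
fn matmul_integer_2x2(a: [[i32; 2]; 2], b: [[i32; 2]; 2], za: i32, zb: i32) -> [[i32; 2]; 2] {
    let mut out = [[0i32; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                out[i][j] += (a[i][k] - za) * (b[k][j] - zb);
            }
        }
    }
    out
}

fn main() {
    // e.g. u8 inputs widened to i32, with zero points 2 and 1:
    // effective A = [[1, 2], [3, 4]], effective B = [[0, 1], [2, 3]]
    let out = matmul_integer_2x2([[3, 4], [5, 6]], [[1, 2], [3, 4]], 2, 1);
    assert_eq!(out, [[4, 7], [8, 15]]);
    println!("{out:?}");
}
```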
Thanks for contributing! 🙏
Just a quick note: matmul is only available for float tensors right now.
We could probably move this to the numeric operations, but I think we would need to expose the cubecl implementation for numerics (iirc it's only exposed for floats, but actually implemented for numerics).
This is probably gonna be a bit of a blocker for the current PR 😅
In case you haven't stumbled upon the contributor guide, there are a couple of steps detailed to add a new operator: https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html#adding-new-operators
Thanks for the quick and kind review! That was my concern as well (the matmul thing).
I'll dig in more and ping you when it's ready 👍
(Thanks for the heads-up about the CI failing, btw.)
@laggui hi! I have a simple question 🔢
To expand matmul to integers, I was trying to define another function in burn-tensor/src/tensor/ops/int_tensor.rs, like below. This is the flow as I understand it: if I add this fn, I need to add it to all the backends as well, I guess (which feels to me like shotgun surgery).
Did I understand the situation correctly? If that's the right direction, I'll keep going. If you have a better design, please let me know!
Thanks :)
Pretty much! The op should be defined in the numeric ops. So the current matmul (public API) would be moved from float to numeric, and then there would be a float_matmul and int_matmul as defined.
As I mentioned above, the matmul for our cubecl backends is only exposed for float I believe, but this could be changed, since the actual implementation in cubecl supports numeric iirc (so integers as well). Would have to make a change in cubecl for that though.
I'm not 100% sure if the other backends will support int matmul, but this can be added progressively (and some implementations left as todo!() during the review process).
That would be the route to go.
Linking the guide if it can be useful: https://burn.dev/contributor-book/guides/adding-a-new-operation-to-burn.html
This is a bit more work than strictly adding an ONNX op for import support 😅 hence why it was in the "harder" ops to add.
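A rough sketch of the split described above, with a toy backend standing in for NdArray/tch. Note the trait and method names (IntTensorOps, int_matmul) follow the discussion, not the actual burn source, and the real burn backend traits take primitive tensor types rather than Vec:

```rust
// Hypothetical per-kind op trait: each backend supplies its own int_matmul
// (or leaves it as todo!() while support is added progressively).
pub trait IntTensorOps {
    type IntTensor;
    fn int_matmul(lhs: Self::IntTensor, rhs: Self::IntTensor) -> Self::IntTensor;
}

// Toy "backend" with a naive row-major i32 matmul, the kind of reference
// implementation an NdArray backend could start from.
struct ToyBackend;

impl IntTensorOps for ToyBackend {
    type IntTensor = Vec<Vec<i32>>;

    fn int_matmul(lhs: Vec<Vec<i32>>, rhs: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        let (n, k, m) = (lhs.len(), rhs.len(), rhs[0].len());
        let mut out = vec![vec![0i32; m]; n];
        for i in 0..n {
            for j in 0..m {
                for p in 0..k {
                    out[i][j] += lhs[i][p] * rhs[p][j];
                }
            }
        }
        out
    }
}

fn main() {
    // (1x2) x (2x1) -> (1x1): 1*3 + 2*4 = 11
    let out = ToyBackend::int_matmul(vec![vec![1, 2]], vec![vec![3], vec![4]]);
    assert_eq!(out, vec![vec![11]]);
}
```

The public `matmul` would then dispatch to `float_matmul` or `int_matmul` depending on the tensor kind.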
Wow, thanks for the superfast reply!
Yeah, it's exactly the "harder" kind of job 😅 but interesting!
For now, I'm trying to add int_matmul and implement it for NdArrayTensor for a simple test (following the guide you shared, thanks!).
Sounds good! I think tch should also be easy, since their matmul supports integers out of the box, iirc.