diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index f167bf17a8..9253b09f32 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -92,6 +92,10 @@ impl IntTensorOps for Autodiff { B::int_mul_scalar(lhs, rhs) } + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::int_matmul(lhs, rhs) + } + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_div(lhs, rhs) } diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 54b80593df..e8e8d5c119 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -211,6 +211,10 @@ impl IntTensorOps for Candle()).unwrap()) } + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + todo!() + } + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) } diff --git a/crates/burn-cubecl/src/ops/int_ops.rs b/crates/burn-cubecl/src/ops/int_ops.rs index acf4bc8f3d..e4e285c387 100644 --- a/crates/burn-cubecl/src/ops/int_ops.rs +++ b/crates/burn-cubecl/src/ops/int_ops.rs @@ -179,6 +179,9 @@ where numeric::mul_scalar::(lhs, rhs) } + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + todo!() + } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::div::(lhs, rhs) } diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index d5f963036b..76d2c21008 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -958,6 +958,9 @@ impl IntTensorOps for Fusion { out } + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + todo!() + } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(DivOps, B::int_div); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index ca108763b2..95cc681b94 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -9,12 +9,13 @@ use super::{ dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode, - pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, - random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode, - random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode, - resize::ResizeNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode, sum::SumNode, - tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + matmul_integer::MatMulIntegerNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, + mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode, + random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode, + random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode, + range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode, + unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::record::PrecisionSettings; @@ -107,6 +108,7 @@ pub enum Node { 
LayerNorm(LayerNormNode), Linear(LinearNode), Matmul(MatmulNode), + MatmulInteger(MatMulIntegerNode), MaxPool1d(MaxPool1dNode), MaxPool2d(MaxPool2dNode), Mean(MeanNode), @@ -163,6 +165,7 @@ macro_rules! match_all { Node::LayerNorm(node) => $func(node), Node::Linear(node) => $func(node), Node::Matmul(node) => $func(node), + Node::MatmulInteger(node) => $func(node), Node::MaxPool1d(node) => $func(node), Node::MaxPool2d(node) => $func(node), Node::Mean(node) => $func(node), diff --git a/crates/burn-import/src/burn/node/matmul_integer.rs b/crates/burn-import/src/burn/node/matmul_integer.rs new file mode 100644 index 0000000000..d87e9cb79e --- /dev/null +++ b/crates/burn-import/src/burn/node/matmul_integer.rs @@ -0,0 +1,213 @@ +use core::cmp::Ordering; + +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone)] +pub struct MatMulIntegerNode { + pub lhs: TensorType, + pub rhs: TensorType, + pub output: TensorType, + pub a_zero_point: Option, + pub b_zero_point: Option, +} + +impl MatMulIntegerNode { + pub fn new( + lhs: TensorType, + rhs: TensorType, + output: TensorType, + a_zero_point: Option, + b_zero_point: Option, + ) -> Self { + // Validate tensor types - using Int for quantized tensors + if lhs.kind != TensorKind::Int || rhs.kind != TensorKind::Int { + panic!("MatMulInteger is only implemented for integer tensors"); + } + + // Output is typically an Int32 tensor in ONNX + if output.kind != TensorKind::Int { + panic!("MatMulInteger output must be an integer tensor"); + } + + // Validate zero points if provided + if let Some(a_zero) = &a_zero_point { + if a_zero.kind != TensorKind::Int { + panic!("A zero point must be an integer tensor"); + } + } + + if let Some(b_zero) = &b_zero_point { + if b_zero.kind != TensorKind::Int { + panic!("B zero point must be an integer tensor"); + } + } + + Self { + lhs, + rhs, + output, + a_zero_point, + b_zero_point, + } + } +} + +impl NodeCodegen for MatMulIntegerNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + let mut input_types = vec![ + Type::Tensor(self.lhs.clone()), + Type::Tensor(self.rhs.clone()), + ]; + if let Some(a_zero_point) = &self.a_zero_point { + input_types.push(Type::Tensor(a_zero_point.clone())); + } + if let Some(b_zero_point) = &self.b_zero_point { + input_types.push(Type::Tensor(b_zero_point.clone())); + } + input_types + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let lhs = scope.tensor_use_owned(&self.lhs, node_position); + let rhs = scope.tensor_use_owned(&self.rhs, node_position); + let output = &self.output.name; + + let a_zero_point = if let Some(a_zero_point) = &self.a_zero_point { + scope.tensor_use_owned(a_zero_point, node_position) + } else { + quote! { 0 } + }; + + let b_zero_point = if let Some(b_zero_point) = &self.b_zero_point { + scope.tensor_use_owned(b_zero_point, node_position) + } else { + quote! { 0 } + }; + + let lhs_dim = self.lhs.dim; + let rhs_dim = self.rhs.dim; + + // Support broadcasting for missing dimensions + match lhs_dim.cmp(&rhs_dim) { + Ordering::Greater => { + let axes = (0..lhs_dim - rhs_dim) + .map(|i| if i % 2 == 0 { 0 } else { -1 }) + .collect::>(); + let axes = axes.to_tokens(); + + if rhs_dim == 1 { + let squeeze_dim = lhs_dim - 1; + quote! 
{ + let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point)).squeeze(#squeeze_dim); + } + } else { + quote! { + let #output = (#lhs - #a_zero_point).matmul((#rhs.unsqueeze_dims(&#axes) - #b_zero_point)); + } + } + } + Ordering::Less => { + let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens(); + + if lhs_dim == 1 { + let squeeze_dim = rhs_dim - 2; + quote! { + let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point)).squeeze(#squeeze_dim); + } + } else { + quote! { + let #output = (#lhs.unsqueeze_dims(&#axes) - #a_zero_point).matmul((#rhs - #b_zero_point)); + } + } + } + Ordering::Equal => quote! { + let #output = (#lhs - #a_zero_point).matmul((#rhs - #b_zero_point)); + }, + } + } + + fn into_node(self) -> Node { + Node::MatmulInteger(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{matmul_integer::MatMulIntegerNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_matmul_integer() { + let mut graph = BurnGraph::::default(); + + graph.register(MatMulIntegerNode::new( + TensorType::new_int("tensor1", 4), + TensorType::new_int("tensor2", 4), + TensorType::new_int("tensor3", 4), + Some(TensorType::new_int("a_zero_point", 1)), + Some(TensorType::new_int("b_zero_point", 1)), + )); + + graph.register_input_output( + vec![ + "tensor1".to_string(), + "tensor2".to_string(), + "a_zero_point".to_string(), + "b_zero_point".to_string(), + ], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor, + a_zero_point: Tensor, + b_zero_point: Tensor, + ) -> Tensor { + let tensor3 = (tensor1 - a_zero_point).matmul((tensor2 - b_zero_point)); + tensor3 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 39154ed979..b0da4bb993 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -25,6 +25,7 @@ pub(crate) mod layer_norm; pub(crate) mod linear; pub(crate) mod mask_where; pub(crate) mod matmul; +pub(crate) mod matmul_integer; pub(crate) mod max_pool1d; pub(crate) mod max_pool2d; pub(crate) mod mean; diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index d4449c0262..0e4a494104 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -7,7 +7,8 @@ use burn_tensor::Distribution; use burn_tensor::ElementConversion; use core::ops::Range; -use ndarray::IntoDimension; +use std::ops::Mul; +use ndarray::{IntoDimension, Ix2}; use ndarray::Zip; // Current crate @@ -162,6 +163,31 @@ impl IntTensorOps fn int_mul_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::mul_scalar(lhs, rhs) } + fn int_matmul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + let lhs_array = &lhs.array; + let rhs_array = &rhs.array; + 
+ // Get dimensions + let m = lhs_array.shape()[0]; + let k = lhs_array.shape()[1]; + let n = rhs_array.shape()[1]; + + // Convert dynamic arrays to 2D views + let lhs_2d = lhs_array.view().into_dimensionality::().unwrap(); + let rhs_2d = rhs_array.view().into_dimensionality::().unwrap(); + + // Create the result array with fixed dimensions + let mut result = ndarray::Array2::::zeros((m, n)); + + // Perform matrix multiplication with fixed dimensions + ndarray::linalg::general_mat_mul(I::one(), &lhs_2d, &rhs_2d, I::zero(), &mut result); + + // Convert back to dynamic dimensions + let result_dyn = result.into_dyn(); + + // Create the tensor + NdArrayTensor::new(result_dyn.into_shared()) + } fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::div(lhs, rhs) diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index fc1eb40cdd..a8725a34c9 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -609,6 +609,9 @@ impl IntTensorOps for BackendRouter { out } + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + todo!() + } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let dtype = lhs.dtype; diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index c0400c8e80..2c6dec1d39 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -135,6 +135,13 @@ impl IntTensorOps for LibTorch { ) } + fn int_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + let lhs_array = &lhs.tensor; + let rhs_array = &rhs.tensor; + let result = lhs_array.dot(rhs_array); + TchTensor::new(result) + } + fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { let copy = false; let non_blocking = true; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index b50d0d0596..b317dcb053 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -178,7 +178,7 @@ where /// # Panics /// /// If the two tensors don't have a compatible shape. - pub fn matmul(self, other: Self) -> Self { + pub fn float_matmul(self, other: Self) -> Self { check!(TensorCheck::matmul(&self, &other)); Self::new(TensorPrimitive::Float(B::float_matmul( self.primitive.tensor(), diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 5d65b68ceb..23edb174d1 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -154,4 +154,42 @@ where pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self { Self::new(B::bitwise_right_shift_scalar(self.primitive, other)) } + + /// Applies the matrix multiplication operation for integer tensors. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors don't have a compatible shape. + pub fn int_matmul(self, other: Self) -> Self { + // Manual dimension checking for matrix multiplication + // we can rewrite check!() for int in numeric.rs + if D < 2 { + panic!("Matrix multiplication requires tensors with at least 2 dimensions"); + } + + let shape_lhs = self.shape(); + let shape_rhs = other.shape(); + + let dim_lhs = shape_lhs.dims[D - 1]; + let dim_rhs = shape_rhs.dims[D - 2]; + + if dim_lhs != dim_rhs { + panic!( + "The inner dimension of matmul should be the same, but got {} and {}. 
Lhs shape {:?}, rhs shape {:?}.", + dim_lhs, dim_rhs, shape_lhs.dims, shape_rhs.dims + ); + } + + // Check device compatibility + if self.device() != other.device() { + panic!("Tensors must be on the same device for matmul operation"); + } + + Self::new(B::int_matmul( + self.primitive, + other.primitive, + )) + } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 9a39601d42..cd742657e4 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -282,6 +282,33 @@ where Self::new(K::mul_scalar::(self.primitive, other)) } + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors don't have a compatible shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let device = B::Device::default(); + /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); + /// let tensor2 = Tensor::::from_data([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], &device); + /// let tensor = tensor1.matmul(tensor2); + /// println!("{tensor}"); + /// // [[16.0, 16.0], [74.0, 98.0]] + /// } + /// ``` + pub fn matmul(self, other: Self) -> Self { + Self::new(K::matmul(self.primitive, other.primitive)) + } + /// Switch sign of each element in the tensor. /// /// `y = -x` @@ -2463,6 +2490,28 @@ where /// which is more high-level and designed for public use. fn neg(tensor: Self::Primitive) -> Self::Primitive; + /// Applies the matrix multiplication operation. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of the matrix multiplication. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For matrix multiplication of tensors, users should prefer the [Tensor::matmul](Tensor::matmul) function, + /// which is more high-level and designed for public use. + fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; + + /// Returns the signs of the elements of a tensor. 
/// /// # Arguments @@ -3495,6 +3544,9 @@ impl Numeric for Int { fn mul_scalar(lhs: Self::Primitive, rhs: E) -> Self::Primitive { B::int_mul_scalar(lhs, rhs.elem()) } + fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + B::int_matmul(lhs, rhs) + } fn neg(tensor: Self::Primitive) -> Self::Primitive { B::int_neg(tensor) } @@ -3829,6 +3881,17 @@ impl Numeric for Float { } } } + fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_matmul(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_matmul(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } + } fn neg(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)), diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 81b73eb2dd..c0b6897393 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1186,6 +1186,32 @@ pub trait IntTensorOps { argsort::(tensor, dim, descending) } + /// Performs matrix multiplication between two integer tensors. + /// + /// For 2D tensors, this computes the standard matrix product: if input tensors + /// are of shapes (n×m) and (m×p), the output will be of shape (n×p). + /// + /// For tensors with more than 2 dimensions, this performs batched matrix multiplication. + /// If the first tensor has shape (b1,b2,...,bn,n,m) and the second tensor has shape + /// (b1,b2,...,bn,m,p), the output will have shape (b1,b2,...,bn,n,p). + /// + /// Broadcasting is supported for non-matching batch dimensions. + /// + /// # Arguments + /// + /// * `lhs` - The left-hand side integer tensor. + /// * `rhs` - The right-hand side integer tensor. + /// + /// # Returns + /// + /// A new integer tensor containing the result of the matrix multiplication. + /// + /// # Panics + /// + /// Panics if the tensors are not compatible for matrix multiplication + /// (i.e., if the number of columns in `lhs` does not equal the number of rows in `rhs`). 
+ fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + /// Bitwise AND operation for Int Tensors fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor; diff --git a/crates/burn-tensor/src/tests/ops/matmul.rs b/crates/burn-tensor/src/tests/ops/matmul.rs index 0e666c8644..1a5428a48b 100644 --- a/crates/burn-tensor/src/tests/ops/matmul.rs +++ b/crates/burn-tensor/src/tests/ops/matmul.rs @@ -217,4 +217,162 @@ mod tests { tensor_3.into_data().assert_eq(&expected, false); } + + /// tests for the matmul-int + #[test] + fn test_int_matmul_d2() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<2>::from_ints([[1, 7], [2, 3], [1, 5]], &device); + let tensor_2 = TestTensorInt::from_ints([[4, 7, 5], [2, 3, 5]], &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[18, 28, 40], [14, 23, 25], [14, 22, 30]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_d3() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device); + let tensor_2 = TestTensorInt::from_ints([[[4, 7], [2, 3]]], &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[[18, 28], [14, 23]]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_broadcast_1() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device); + let tensor_2 = TestTensorInt::from_ints( + [[[4, 7], [2, 3]], [[2, 5], [6, 3]]], + &device, + ); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = + TensorData::from([[[18, 28], [14, 23]], [[44, 26], [22, 19]]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_broadcast_4d() { + let device = Default::default(); + // [2, 1, 2, 2] + let tensor_1 = TestTensorInt::<4>::from_ints( + [[[[1, 7], [2, 3]]], [[[2, 5], [6, 3]]]], + &device, + ); + // [1, 2, 2, 2] + let tensor_2 = TestTensorInt::from_ints( + [[[[9, 8], [1, 4]], [[2, 7], [3, 5]]]], + &device, + ); + + // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([ + [[[16, 36], [21, 28]], [[23, 42], [13, 29]]], + [[[23, 36], [57, 60]], [[19, 39], [21, 57]]], + ]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_simple_1() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<2>::from_ints([[5, 14], [14, 50]], &device); + let tensor_2 = TestTensorInt::from_ints([[3, 4, 5], [0, 1, 2]], &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[15, 34, 53], [42, 106, 170]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_4_3() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<2>::from_ints( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + &device, + ); + let tensor_2 = TestTensorInt::from_ints( + [[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]], + &device, + ); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[56, 62, 68], [152, 174, 196], [248, 286, 324]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + fn test_int_matmul_trivial() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]); + let tensor_3 = tensor_1.clone().matmul(tensor_1); + + tensor_3.into_data().assert_eq( + &TensorData::from([ + 
[56, 62, 68, 74], + [152, 174, 196, 218], + [248, 286, 324, 362], + [344, 398, 452, 506], + ]), + false, + ); + } + + #[test] + fn test_int_matmul_trivial_transposed() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]); + let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); + + tensor_3.into_data().assert_eq( + &TensorData::from([ + [14, 38, 62, 86], + [38, 126, 214, 302], + [62, 214, 366, 518], + [86, 302, 518, 734], + ]), + false, + ); + } + + #[test] + fn test_int_matmul_simple_2() { + let device = Default::default(); + let tensor_1 = TestTensorInt::<2>::from_ints([[1, 2, 3, 4]], &device); + let tensor_2 = TestTensorInt::from_ints([[3], [4], [5], [6]], &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[50]]); + + tensor_3.into_data().assert_eq(&expected, false); + } + + #[test] + #[should_panic] + fn should_panic_when_inner_dimensions_are_not_equal_int() { + let device = Default::default(); + let tensor_1 = + TestTensorInt::<2>::from_ints([[3, 3], [4, 4], [5, 5], [6, 6]], &device); + let tensor_2 = TestTensorInt::from_ints( + [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], + &device, + ); + + // This should panic as dimensions don't match + let _ = tensor_1.matmul(tensor_2); + } } + +
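
For reference, a minimal usage sketch of the integer matmul surface added in this diff, assuming an arbitrary backend `B` (the function name `example` is illustrative); the `Int` tensor `matmul` call routes through the new `Numeric::matmul` impl in `numeric.rs` and down to the backend's `int_matmul`:

use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn example<B: Backend>() {
    let device = B::Device::default();
    // (2x3) x (3x2) -> (2x2), mirroring the float doc example in `numeric.rs`.
    let lhs = Tensor::<B, 2, Int>::from_ints([[1, 2, 3], [4, 5, 6]], &device);
    let rhs = Tensor::<B, 2, Int>::from_ints([[7, 8], [9, 10], [11, 12]], &device);
    let out = lhs.matmul(rhs);
    // [[58, 64], [139, 154]]
    println!("{out}");
}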