Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]: Jax tensor donation #3224

Closed
wants to merge 22 commits into from
Prev Previous commit
Next Next commit
fixes and simplifications
Max Manainen authored and Max Manainen committed Oct 15, 2024
commit 9f6ae3891febdb6136be31a1775250a75d00fb52
64 changes: 33 additions & 31 deletions tests/filecheck/transforms/jax-use-donated-arguments.mlir
Original file line number Diff line number Diff line change
@@ -2,53 +2,55 @@

builtin.module {
func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
%0 = tensor.empty() : tensor<2x4xf32>
%cst = arith.constant 0.000000e+00 : f32
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32>
return %1 : tensor<2x4xf32>
%res = "test.op"() : () -> tensor<2x4xf32>
return %res : tensor<2x4xf32>
}
}

// CHECK: builtin.module {
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> {
// CHECK-NEXT: %0 = tensor.empty() : tensor<2x4xf32>
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK-NEXT: %2 = bufferization.materialize_in_destination %1 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK-NEXT: func.return %2 : tensor<2x4xf32>
// CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK-NEXT: func.return %0 : tensor<2x4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }

builtin.module {
func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32>, %arg2: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
%cst = arith.constant 0.000000e+00 : f32

%0 = tensor.empty() : tensor<2x3xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32>

%2 = tensor.empty() : tensor<2x3xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32>
func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
%res1 = "test.op"() : () -> tensor<2x3xf32>
%res2 = "test.op"() : () -> tensor<2x3xf32>
return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32>
}
}

%4 = tensor.empty() : tensor<4x5xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: func.return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }

return %1, %3, %5 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
builtin.module {
func.func public @main(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
%res1 = "test.op"() : () -> tensor<2x3xf32>
%res2 = "test.op"() : () -> tensor<2x3xf32>
%res3 = "test.op"() : () -> tensor<4x5xf32>
return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
}
}

// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32>, %arg2 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %0 = tensor.empty() : tensor<2x3xf32>
// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %2 = tensor.empty() : tensor<2x3xf32>
// CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %4 = tensor.empty() : tensor<4x5xf32>
// CHECK-NEXT: %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: %6 = bufferization.materialize_in_destination %1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %7 = bufferization.materialize_in_destination %5 in %arg2 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: func.return %6, %3, %7 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
// CHECK-NEXT: func.func public @main(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: func.return %0, %res2, %1 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }

13 changes: 0 additions & 13 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
@@ -800,19 +800,6 @@ def get_shape(self) -> tuple[int, ...]:
def get_element_type(self) -> AttributeCovT:
return self.element_type

def is_same_type_with(self, other_tensor: TensorType[Attribute]) -> bool:
current_shape = list(self.shape)
other_shape = list(other_tensor.shape)
if len(current_shape) != len(other_shape):
return False

return (
len(list(filter(lambda x: x[0] != x[1], zip(current_shape, other_shape))))
== 0
and self.element_type == other_tensor.element_type
and self.encoding == other_tensor.encoding
)


AnyTensorType: TypeAlias = TensorType[Attribute]
AnyTensorTypeConstr = BaseAttr[TensorType[Attribute]](TensorType)
13 changes: 7 additions & 6 deletions xdsl/transforms/jax_use_donated_arguments.py
Original file line number Diff line number Diff line change
@@ -17,10 +17,6 @@
from xdsl.utils.exceptions import VerifyException


def make_materialize_op(source: SSAValue, dest: SSAValue) -> MaterializeInDestination:
return MaterializeInDestination(operands=[source, dest], result_types=[source.type])


@dataclass
class SubstituteDonatedTensors(RewritePattern):
@op_type_rewrite_pattern
@@ -42,8 +38,13 @@ def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /):
new_ops: list[Operation] = []
for output in op.arguments:
for i, arg in enumerate(donated_inputs):
if getattr(arg, "type").is_same_type_with(getattr(output, "type")):
new_ops.append(make_materialize_op(output, donated_inputs.pop(i)))
if arg.type == output.type:
new_ops.append(
MaterializeInDestination(
operands=[output, donated_inputs.pop(i)],
result_types=[output.type],
)
)
value_mapper[output] = new_ops[-1].results[0]
break