Skip to content

Commit 7869d09

Browse files
Max ManainenMax Manainen
Max Manainen
authored and
Max Manainen
committed
broken parsing example
1 parent 9b1bf3e commit 7869d09

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s
2+
// other passes like convert-linalg-to-loops have the same problems
3+
4+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
5+
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
6+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
7+
module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
8+
func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
9+
%0 = tensor.empty() : tensor<2x4xf32>
10+
%cst = arith.constant 0.000000e+00 : f32
11+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32>
12+
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) {
13+
^bb0(%in: f32, %in_0: f32, %out: f32):
14+
%3 = arith.mulf %in, %in_0 : f32
15+
%4 = arith.addf %out, %3 : f32
16+
linalg.yield %4 : f32
17+
} -> tensor<2x4xf32>
18+
return %2 : tensor<2x4xf32>
19+
}
20+
}

0 commit comments

Comments
 (0)