diff --git a/build_tools/riscv/mips_matmul_test.mlir b/build_tools/riscv/mips_matmul_test.mlir new file mode 100644 index 000000000000..f3e4d09c83ca --- /dev/null +++ b/build_tools/riscv/mips_matmul_test.mlir @@ -0,0 +1,101 @@ +// mips_matmul_test.mlir +// +// End-to-end test inputs for the MIPS matmul kernel pipeline. +// +// Each function exercises torch.aten.mm, which is intercepted by +// ConvertTorchToMIPSPass and rewritten as mips.matmul. The op is then +// eliminated during One-Shot Bufferize: MIPSBufferizableOpInterface +// decomposes the 2-D memrefs and emits a direct func.call to the +// hand-tuned C kernel: +// +// torch.aten.mm +// → mips.matmul (ConvertTorchToMIPSPass) +// → flow.dispatch(...) (IREE dispatch formation) +// → func.call @my_matmul_kernel (MIPSBufferizableOpInterface) +// → ELF inside .vmfb (iree-compile LLVMCPU backend) +// +// Usage: +// bash build_tools/riscv/rvv_qemu_workflow_static.sh -- static (.o baked into vmfb) +// bash build_tools/riscv/rvv_qemu_workflow_dynamic.sh -- dynamic (.so plugin at runtime) + +module { + // ── Test 1: 4×4 identity × data → passthrough ──────────────────────────── + // Verifies that A=I leaves B unchanged; a simple correctness smoke-test. + // + // A = identity(4×4), B = [[1..4],[5..8],[9..12],[13..16]] + // Expected: result = B + func.func @matmul_4x4( + %A : !torch.vtensor<[4,4],f32>, + %B : !torch.vtensor<[4,4],f32>) + -> !torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> + } + + // ── Test 2: 2×3 × 3×2 → 2×2 (non-square, reduced K dimension) ────────── + // Verifies M≠N≠K path through the kernel (inner loop trip-count < vlen). 
+ // + // A = [[1,2,3],[4,5,6]], B = [[1,0],[0,1],[1,0]] + // Expected: [[1+0+3, 0+2+0],[4+0+6, 0+5+0]] = [[4,2],[10,5]] + func.func @matmul_2x3x2( + %A : !torch.vtensor<[2,3],f32>, + %B : !torch.vtensor<[3,2],f32>) + -> !torch.vtensor<[2,2],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,2],f32> + -> !torch.vtensor<[2,2],f32> + return %0 : !torch.vtensor<[2,2],f32> + } + + // ── Test 3: 8×8 × 8×8 → 8×8 (exercises multi-vector-register tiling) ─── + // With vlen=512 and LMUL=m4, N=8 fits in a single VL group. This test + // stresses the vectorized inner loop and accumulation across K=8 steps. + // + // A = upper-triangular ones (row i has ones in columns 0..i). + // B = identity(8×8). + // Expected: A*I = A — result is upper-triangular ones. + // + // A row layout (8×8): + // row 0: [1,0,0,0,0,0,0,0] + // row 1: [1,1,0,0,0,0,0,0] + // row 2: [1,1,1,0,0,0,0,0] + // ... + // row 7: [1,1,1,1,1,1,1,1] + func.func @matmul_8x8( + %A : !torch.vtensor<[8,8],f32>, + %B : !torch.vtensor<[8,8],f32>) + -> !torch.vtensor<[8,8],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[8,8],f32>, !torch.vtensor<[8,8],f32> + -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Expected outputs (iree-run-module) +// ───────────────────────────────────────────────────────────────────────────── +// +// matmul_4x4 A=identity(4x4), B=[1..16 row-major]: +// result[0]: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16] +// +// matmul_2x3x2 A=[[1,2,3],[4,5,6]], B=[[1,0],[0,1],[1,0]]: +// result[0]: 2x2xf32=[4 2][10 5] +// +// matmul_8x8 A=upper-triangular-ones(8x8), B=identity(8x8): +// result[0]: 8x8xf32= +// [1 0 0 0 0 0 0 0] +// [1 1 0 0 0 0 0 0] +// [1 1 1 0 0 0 0 0] +// [1 1 1 1 0 0 0 0] +// [1 1 1 1 1 0 0 0] +// [1 1 1 1 1 1 0 0] +// [1 1 1 1 1 1 1 0] +// [1 1 1 1 1 1 1 1] +// +// iree-run-module invocation for matmul_8x8: +// 
--function=matmul_8x8 +// "--input=8x8xf32=1,0,0,0,0,0,0,0, 1,1,0,0,0,0,0,0, 1,1,1,0,0,0,0,0, 1,1,1,1,0,0,0,0, 1,1,1,1,1,0,0,0, 1,1,1,1,1,1,0,0, 1,1,1,1,1,1,1,0, 1,1,1,1,1,1,1,1" +// "--input=8x8xf32=1,0,0,0,0,0,0,0, 0,1,0,0,0,0,0,0, 0,0,1,0,0,0,0,0, 0,0,0,1,0,0,0,0, 0,0,0,0,1,0,0,0, 0,0,0,0,0,1,0,0, 0,0,0,0,0,0,1,0, 0,0,0,0,0,0,0,1" diff --git a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh new file mode 100755 index 000000000000..0d906813071c --- /dev/null +++ b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh @@ -0,0 +1,206 @@ +#!/usr/bin/env bash +# rvv_qemu_workflow_dynamic.sh +# +# End-to-end MIPS matmul pipeline — DYNAMIC plugin loading. +# +# The RVV kernel is compiled into a shared library (.so) that is loaded at +# runtime via --executable_plugin. No custom linker wrapper is needed at +# iree-compile time. +# +# Pipeline: +# mips_matmul_test.mlir +# ─[iree-opt torch-to-iree{use-mips-matmul=true}]─► flow.mlir +# ─[clang --target=riscv64 -shared]──────────────► librvv_matmul.so +# ─[iree-compile --iree-llvmcpu-link-embedded=false]► matmul.vmfb +# ─[qemu-riscv64 iree-run-module --executable_plugin]► result +# +# Usage: +# bash rvv_qemu_workflow_dynamic.sh # RISC-V QEMU, vlen=512 +# bash rvv_qemu_workflow_dynamic.sh --host # x86 host (scalar fallback) +# bash rvv_qemu_workflow_dynamic.sh --vlen 256 # QEMU with vlen=256 + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration +# ───────────────────────────────────────────────────────────────────────────── +IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" +WORK_DIR="${HOME}/MLIR_Work/mips" +HOST_BUILD="${WORK_DIR}/iree-build" +HOST_INSTALL="${HOST_BUILD}/install" +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" +OUT_DIR="${WORK_DIR}/out/dynamic" + +IREE_OPT="${HOST_INSTALL}/bin/iree-opt" +IREE_COMPILE="${HOST_INSTALL}/bin/iree-compile" +HOST_RUN="${HOST_INSTALL}/bin/iree-run-module" +RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" +QEMU="${HOME}/local/bin/qemu-riscv64" +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" + +CLANG="${HOME}/miniforge3/bin/clang" +LLD="${HOME}/miniforge3/bin/ld.lld" +CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" + +KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" +PLUGIN_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_plugin.c" +TEST_MLIR="${IREE_SRC}/build_tools/riscv/mips_matmul_test.mlir" + +# Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. +export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +HOST_MODE=0 +VLEN=512 +while [[ $# -gt 0 ]]; do + case "$1" in + --host) HOST_MODE=1 ;; + --vlen) shift; VLEN="$1" ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac + shift +done + +mkdir -p "${OUT_DIR}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +section() { echo ""; echo "══[ $* ]══════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +run_qemu() { + local vlen="$1"; shift + "${QEMU}" -cpu "rv64,v=true,vlen=${vlen},elen=64,vext_spec=v1.0" \ + -L "${SYSROOT}" "${RISCV_RUN}" "$@" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: torch → IREE flow IR +# 
───────────────────────────────────────────────────────────────────────────── +section "Step 1: torch → IREE flow IR" + +"${IREE_OPT}" \ + --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ + "${TEST_MLIR}" -o "${OUT_DIR}/flow.mlir" +ok "${OUT_DIR}/flow.mlir" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Cross-compile kernel + plugin → shared library +# +# matmul_kernel.c — compute logic (no IREE headers) +# matmul_plugin.c — IREE HAL executable plugin interface +# Both compiled together into a single -fPIC -shared .so. +# ───────────────────────────────────────────────────────────────────────────── +section "Step 2: Compile matmul_kernel.c + matmul_plugin.c → .so" + +PLUGIN_SO="${OUT_DIR}/librvv_matmul.so" + +if [[ "${HOST_MODE}" == "1" ]]; then + "${CLANG}" --target=x86_64-linux-gnu \ + -O2 -fPIC -shared \ + -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" "${PLUGIN_SRC}" -o "${PLUGIN_SO}" + ok "x86 scalar plugin: ${PLUGIN_SO}" +else + "${CLANG}" --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -fPIC -shared -nostdinc -nostdlib \ + -isystem "${CLANG_INC}" \ + -fuse-ld="${LLD}" \ + -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" "${PLUGIN_SRC}" -o "${PLUGIN_SO}" + ok "RISC-V RVV plugin: ${PLUGIN_SO} ($(file -b "${PLUGIN_SO}" | cut -d, -f1))" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: iree-compile → .vmfb (kernel resolved at runtime via plugin) +# +# --iree-llvmcpu-link-embedded=false — host-ABI shared object, not embedded ELF +# No --iree-mips-static-embedding — my_matmul_kernel is a HAL import entry +# ───────────────────────────────────────────────────────────────────────────── +section "Step 3: iree-compile → .vmfb (dynamic)" + +if [[ "${HOST_MODE}" == "1" ]]; then + RISCV_FLAGS=() +else + RISCV_FLAGS=( + "--iree-llvmcpu-target-triple=riscv64-linux-gnu" + "--iree-llvmcpu-target-abi=lp64d" + 
"--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+c,+zvl512b,+v" + "--riscv-v-fixed-length-vector-lmul-max=8" + ) +fi + +VMFB="${OUT_DIR}/matmul_dynamic.vmfb" +"${IREE_COMPILE}" \ + --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-link-embedded=false \ + "${RISCV_FLAGS[@]}" \ + "${OUT_DIR}/flow.mlir" -o "${VMFB}" +ok "${VMFB} ($(du -sh "${VMFB}" | cut -f1))" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: Verify kernel appears as an import (unresolved symbol) in the vmfb +# ───────────────────────────────────────────────────────────────────────────── +section "Step 4: Verify dynamic import in vmfb" + +ELF_OFFSET=$(grep -boa $'\x7fELF' "${VMFB}" 2>/dev/null | head -1 | cut -d: -f1 || true) +if [[ -n "${ELF_OFFSET}" ]]; then + dd if="${VMFB}" bs=1 skip="${ELF_OFFSET}" 2>/dev/null > "${OUT_DIR}/dispatch.elf" + python3 - "${OUT_DIR}/dispatch.elf" << 'PYEOF' +import sys +data = open(sys.argv[1], 'rb').read() +idx = data.find(b'my_matmul_kernel') +rvv = sum(1 for i in range(0, len(data)-3, 4) if data[i] & 0x7f == 0x57) +if idx != -1: + print(f" [ok] 'my_matmul_kernel' at offset {idx} (import table entry — kernel lives in .so)") +else: + print(" [warn] 'my_matmul_kernel' not found in dispatch ELF") +PYEOF +else + echo " [warn] No ELF found in vmfb" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 5: Run (kernel loaded from .so via --executable_plugin) +# ───────────────────────────────────────────────────────────────────────────── +section "Step 5: Run (--executable_plugin=${PLUGIN_SO})" + +MATMUL_ARGS=( + --module="${VMFB}" + --executable_plugin="${PLUGIN_SO}" + --function="matmul_4x4" + "--input=4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" + "--input=4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" +) + +if [[ "${HOST_MODE}" == "1" ]]; then + echo " Running on x86 host (scalar fallback)..." + "${HOST_RUN}" "${MATMUL_ARGS[@]}" +else + echo " Running under QEMU vlen=${VLEN}..." 
+ run_qemu "${VLEN}" "${MATMUL_ARGS[@]}" + + echo "" + echo " VLEN sweep:" + for V in 128 256 512; do + printf " vlen=%-4s " "${V}:" + run_qemu "${V}" "${MATMUL_ARGS[@]}" 2>&1 | grep "4x4xf32" || echo "(no output)" + done + echo " Note: vlen=128 may produce zeros — vmfb compiled with +zvl512b" +fi + +echo "" +echo " Expected: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16]" + +# ───────────────────────────────────────────────────────────────────────────── +# Summary +# ───────────────────────────────────────────────────────────────────────────── +echo "" +echo "════════════════════════════════════════════════════════════" +echo " DONE — Dynamic plugin verified." +echo " Artifacts in ${OUT_DIR}/" +echo " librvv_matmul.so — plugin loaded at runtime" +echo " matmul_dynamic.vmfb — vmfb with HAL import (needs plugin)" +echo "════════════════════════════════════════════════════════════" diff --git a/build_tools/riscv/rvv_qemu_workflow_static.sh b/build_tools/riscv/rvv_qemu_workflow_static.sh new file mode 100644 index 000000000000..c40d531b450b --- /dev/null +++ b/build_tools/riscv/rvv_qemu_workflow_static.sh @@ -0,0 +1,220 @@ +#!/usr/bin/env bash +# rvv_qemu_workflow_static.sh +# +# End-to-end MIPS matmul pipeline — STATIC kernel embedding. +# +# The RVV kernel (.o) is baked into the dispatch ELF inside the .vmfb at +# iree-compile time via a custom lld wrapper. No plugin .so is needed at +# runtime. 
+# +# Pipeline: +# mips_matmul_test.mlir +# ─[iree-opt torch-to-iree{use-mips-matmul=true}]─► flow.mlir +# ─[clang --target=riscv64]──────────────────────► matmul_kernel_riscv.o +# ─[lld_wrapper.sh] (appends .o to every dispatch link) +# ─[iree-compile --iree-mips-static-embedding]────► matmul.vmfb +# ─[qemu-riscv64 iree-run-module]────────────────► result (no --executable_plugin) +# +# Usage: +# bash rvv_qemu_workflow_static.sh # RISC-V QEMU, vlen=512 +# bash rvv_qemu_workflow_static.sh --host # x86 host (scalar fallback) +# bash rvv_qemu_workflow_static.sh --vlen 256 # QEMU with vlen=256 + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration +# ───────────────────────────────────────────────────────────────────────────── +IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +WORK_DIR="${HOME}/MLIR_Work/mips" +HOST_BUILD="${WORK_DIR}/iree-build" +HOST_INSTALL="${HOST_BUILD}/install" +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" +OUT_DIR="${WORK_DIR}/out/static" + +IREE_OPT="${HOST_INSTALL}/bin/iree-opt" +IREE_COMPILE="${HOST_INSTALL}/bin/iree-compile" +HOST_RUN="${HOST_INSTALL}/bin/iree-run-module" +RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" +QEMU="${HOME}/local/bin/qemu-riscv64" +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" + +CLANG="${HOME}/miniforge3/bin/clang" +LLD="${HOME}/miniforge3/bin/ld.lld" +CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" + +KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" +TEST_MLIR="${IREE_SRC}/build_tools/riscv/mips_matmul_test.mlir" + +# Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. 
+export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +HOST_MODE=0 +VLEN=512 +while [[ $# -gt 0 ]]; do + case "$1" in + --host) HOST_MODE=1 ;; + --vlen) shift; VLEN="$1" ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac + shift +done + +mkdir -p "${OUT_DIR}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +section() { echo ""; echo "══[ $* ]══════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +run_qemu() { + local vlen="$1"; shift + "${QEMU}" -cpu "rv64,v=true,vlen=${vlen},elen=64,vext_spec=v1.0" \ + -L "${SYSROOT}" "${RISCV_RUN}" "$@" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: torch → IREE flow IR +# ───────────────────────────────────────────────────────────────────────────── +section "Step 1: torch → IREE flow IR" + +"${IREE_OPT}" \ + --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ + "${TEST_MLIR}" -o "${OUT_DIR}/flow.mlir" +ok "${OUT_DIR}/flow.mlir" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Cross-compile RVV kernel → RISC-V relocatable object +# ───────────────────────────────────────────────────────────────────────────── +section "Step 2: Compile matmul_kernel.c → .o" + +KERNEL_O="${OUT_DIR}/matmul_kernel_riscv.o" + +if [[ "${HOST_MODE}" == "1" ]]; then + "${CLANG}" --target=x86_64-linux-gnu \ + -O2 -c -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" -o "${KERNEL_O}" + ok "x86 scalar kernel: ${KERNEL_O}" +else + "${CLANG}" --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -c -nostdinc -isystem "${CLANG_INC}" \ + -I "${IREE_SRC}/runtime/src" \ + 
"${KERNEL_SRC}" -o "${KERNEL_O}" + ok "RISC-V RVV kernel: ${KERNEL_O} ($(file -b "${KERNEL_O}" | cut -d, -f1))" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: Create lld wrapper +# +# IREE calls its embedded linker as: +# lld -flavor gnu --no-undefined -nostdlib -static -shared ... dispatch.o +# We append the kernel .o so my_matmul_kernel resolves at link time. +# -Bsymbolic: bind all same-ELF symbols locally to avoid R_RISCV_JUMP_SLOT +# entries in .rela.plt — IREE's embedded ELF loader ignores DT_JMPREL +# (.rela.plt) and only processes DT_RELA (.rela.dyn). Without this flag, +# the PLT GOT slot is never patched, causing a segfault on the first call. +# ───────────────────────────────────────────────────────────────────────────── +section "Step 3: Create lld_wrapper.sh" + +LLD_WRAPPER="${OUT_DIR}/lld_wrapper.sh" +cat > "${LLD_WRAPPER}" << WRAPPER +#!/usr/bin/env bash +exec "${LLD}" "\$@" "${KERNEL_O}" -Bsymbolic +WRAPPER +chmod +x "${LLD_WRAPPER}" +ok "${LLD_WRAPPER} (appends ${KERNEL_O})" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: iree-compile → .vmfb (kernel statically linked) +# ───────────────────────────────────────────────────────────────────────────── +section "Step 4: iree-compile → .vmfb (static)" + +if [[ "${HOST_MODE}" == "1" ]]; then + RISCV_FLAGS=() +else + RISCV_FLAGS=( + "--iree-llvmcpu-target-triple=riscv64-linux-gnu" + "--iree-llvmcpu-target-abi=lp64d" + "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+c,+zvl512b,+v" + "--riscv-v-fixed-length-vector-lmul-max=8" + ) +fi + +VMFB="${OUT_DIR}/matmul_static.vmfb" +"${IREE_COMPILE}" \ + --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-link-embedded=true \ + --iree-llvmcpu-embedded-linker-path="${LLD_WRAPPER}" \ + --iree-mips-static-embedding \ + "${RISCV_FLAGS[@]}" \ + "${OUT_DIR}/flow.mlir" -o "${VMFB}" +ok "${VMFB} ($(du -sh "${VMFB}" | cut -f1))" + +# 
───────────────────────────────────────────────────────────────────────────── +# Step 5: Verify kernel is embedded in the dispatch ELF +# ───────────────────────────────────────────────────────────────────────────── +section "Step 5: Verify static embedding" + +ELF_OFFSET=$(grep -boa $'\x7fELF' "${VMFB}" 2>/dev/null | head -1 | cut -d: -f1 || true) +if [[ -n "${ELF_OFFSET}" ]]; then + dd if="${VMFB}" bs=1 skip="${ELF_OFFSET}" 2>/dev/null > "${OUT_DIR}/dispatch.elf" + python3 - "${OUT_DIR}/dispatch.elf" << 'PYEOF' +import sys +data = open(sys.argv[1], 'rb').read() +idx = data.find(b'my_matmul_kernel') +rvv = sum(1 for i in range(0, len(data)-3, 4) if data[i] & 0x7f == 0x57) +if idx != -1: + tag = "[ok]" if rvv > 0 else "[warn]" + print(f" {tag} 'my_matmul_kernel' at offset {idx}, RVV instructions: {rvv}") +else: + print(" [warn] 'my_matmul_kernel' not found in dispatch ELF") +PYEOF +else + echo " [warn] No ELF found in vmfb" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 6: Run +# ───────────────────────────────────────────────────────────────────────────── +section "Step 6: Run (no --executable_plugin)" + +MATMUL_ARGS=( + --module="${VMFB}" + --function="matmul_4x4" + "--input=4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" + "--input=4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" +) + +if [[ "${HOST_MODE}" == "1" ]]; then + echo " Running on x86 host (scalar fallback)..." + "${HOST_RUN}" "${MATMUL_ARGS[@]}" +else + echo " Running under QEMU vlen=${VLEN}..." 
+ run_qemu "${VLEN}" "${MATMUL_ARGS[@]}" + + echo "" + echo " VLEN sweep:" + for V in 128 256 512; do + printf " vlen=%-4s " "${V}:" + run_qemu "${V}" "${MATMUL_ARGS[@]}" 2>&1 | grep "4x4xf32" || echo "(no output)" + done + echo " Note: vlen=128 may produce zeros — vmfb compiled with +zvl512b" +fi + +echo "" +echo " Expected: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16]" + +# ───────────────────────────────────────────────────────────────────────────── +# Summary +# ───────────────────────────────────────────────────────────────────────────── +echo "" +echo "════════════════════════════════════════════════════════════" +echo " DONE — Static embedding verified." +echo " Artifacts in ${OUT_DIR}/" +echo " matmul_kernel_riscv.o — kernel object (baked into vmfb)" +echo " lld_wrapper.sh — linker interceptor" +echo " matmul_static.vmfb — self-contained vmfb (no plugin at runtime)" +echo "════════════════════════════════════════════════════════════" diff --git a/build_tools/riscv/setup_qemu_workflow.sh b/build_tools/riscv/setup_qemu_workflow.sh new file mode 100644 index 000000000000..792cf1353f87 --- /dev/null +++ b/build_tools/riscv/setup_qemu_workflow.sh @@ -0,0 +1,278 @@ +#!/usr/bin/env bash +# setup_qemu_workflow.sh +# +# One-time setup for the MIPS/RVV QEMU workflow on a Linux host (Rocky 8). +# Installs toolchain, QEMU, and builds both IREE host and RISC-V targets. +# +# Steps (run all by default; pass --step=N to run one): +# 1. Install toolchain — ninja, clang-18, lld-18 via conda-forge +# 2. Install sysroot — RISC-V prebuilt sysroot + iree-run-module (RISC-V) +# 3. Build QEMU — qemu-riscv64 user-mode from source +# 4. Build IREE (host) — iree-opt, iree-compile, iree-run-module for x86 +# 5. 
Build IREE (riscv) — iree-run-module cross-compiled for RISC-V +# +# Usage: +# bash setup_qemu_workflow.sh # run all steps +# bash setup_qemu_workflow.sh --step=3 # run only QEMU build +# bash setup_qemu_workflow.sh --step=4 # run only host IREE build + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration — edit to match your environment +# ───────────────────────────────────────────────────────────────────────────── +WORK_DIR="${HOME}/MLIR_Work/mips" +IREE_SRC="${WORK_DIR}/iree" + +HOST_BUILD="${WORK_DIR}/iree-build" # iree-opt, iree-compile (x86) +HOST_INSTALL="${HOST_BUILD}/install" # installed host tools +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" # iree-run-module (RISC-V) +QEMU_VER="8.2.2" +INSTALL_PREFIX="${HOME}/local" # qemu-riscv64 installed here + +CONDA="${HOME}/miniforge3/bin/conda" +CLANG="${HOME}/miniforge3/bin/clang" +CLANGXX="${HOME}/miniforge3/bin/clang++" +NINJA="${INSTALL_PREFIX}/bin/ninja" + +# RISC-V prebuilt sysroot (downloaded by riscv_bootstrap.sh) +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" +RISCV_TOOLCHAIN="${HOME}/riscv/toolchain/clang/linux/RISCV" + +# Rocky 8's system libstdc++ is too old; conda's copy has GLIBCXX 3.4.29+. 
+export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +STEP_ONLY=0 +for arg in "$@"; do + case "${arg}" in + --step=*) STEP_ONLY="${arg#--step=}" ;; + *) echo "Unknown arg: ${arg}"; exit 1 ;; + esac +done + +should_run() { [[ "${STEP_ONLY}" == "0" || "${STEP_ONLY}" == "$1" ]]; } + +log() { echo ""; echo "════════════════════════════════════════════════════════════"; echo " $*"; echo "════════════════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +skip() { echo " [skip] $*"; } +die() { echo " [FAIL] $*" >&2; exit 1; } + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: Install toolchain (ninja, clang-18, lld-18 via conda-forge) +# ───────────────────────────────────────────────────────────────────────────── +step1_toolchain() { + log "STEP 1: Install toolchain" + + # Miniforge (conda base) + if [[ -x "${CONDA}" ]]; then + skip "conda already at ${CONDA}" + else + local tmp; tmp="$(mktemp /tmp/miniforge_XXXXX.sh)" + echo " Downloading Miniforge..." 
+ curl -fsSL "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \ + -o "${tmp}" + bash "${tmp}" -b -p "${HOME}/miniforge3" + rm -f "${tmp}" + ok "conda installed" + fi + + # ninja + if [[ -x "${NINJA}" ]]; then + skip "ninja already at ${NINJA}" + else + local tmp; tmp="$(mktemp /tmp/ninja_XXXXX.zip)" + curl -fsSL "https://github.com/ninja-build/ninja/releases/download/v1.12.1/ninja-linux.zip" \ + -o "${tmp}" + mkdir -p "${INSTALL_PREFIX}/bin" + unzip -qo "${tmp}" -d "${INSTALL_PREFIX}/bin" + chmod +x "${NINJA}" + rm -f "${tmp}" + ok "ninja installed" + fi + + # clang-18 + lld-18 + if [[ -x "${CLANG}" ]]; then + skip "clang already at ${CLANG} ($(${CLANG} --version | head -1))" + else + echo " Installing clang-18 + lld-18 (this may take a few minutes)..." + "${CONDA}" install -y -c conda-forge "clang=18" "clangxx=18" "lld=18" --no-update-deps + ok "clang-18 + lld-18 installed" + fi + + ok "Toolchain ready" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Install RISC-V sysroot (IREE prebuilt) +# ───────────────────────────────────────────────────────────────────────────── +step2_sysroot() { + log "STEP 2: Install RISC-V sysroot" + + if [[ -d "${SYSROOT}" ]]; then + skip "Sysroot already at ${SYSROOT}" + return + fi + + echo " Running riscv_bootstrap.sh (interactive — prompts for download paths)..." 
+ bash "$(dirname "${BASH_SOURCE[0]}")/riscv_bootstrap.sh" + ok "Sysroot installed at ${SYSROOT}" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: Build QEMU riscv64-linux-user from source +# ───────────────────────────────────────────────────────────────────────────── +step3_qemu() { + log "STEP 3: Build QEMU ${QEMU_VER} (riscv64-linux-user)" + + local qemu_bin="${INSTALL_PREFIX}/bin/qemu-riscv64" + if [[ -x "${qemu_bin}" ]]; then + skip "qemu-riscv64 already at ${qemu_bin} ($(${qemu_bin} --version | head -1))" + return + fi + + "${CONDA}" install -y -c conda-forge glib pkg-config 2>&1 | tail -3 + + local tarball="${WORK_DIR}/qemu-${QEMU_VER}.tar.xz" + if [[ ! -f "${tarball}" ]]; then + echo " Downloading QEMU ${QEMU_VER}..." + curl -fsSL --progress-bar "https://download.qemu.org/qemu-${QEMU_VER}.tar.xz" -o "${tarball}" + else + skip "Tarball already downloaded" + fi + + local src="${WORK_DIR}/qemu-${QEMU_VER}" + if [[ ! -d "${src}" ]]; then + echo " Extracting QEMU source..." + tar -xf "${tarball}" -C "${WORK_DIR}" + fi + + export PKG_CONFIG_PATH="${HOME}/miniforge3/lib/pkgconfig:${HOME}/miniforge3/share/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}}" + export PKG_CONFIG="${HOME}/miniforge3/bin/pkg-config" + export LDFLAGS="-Wl,-rpath,${HOME}/miniforge3/lib" + + echo " Configuring QEMU..." + cd "${src}" + ./configure \ + --prefix="${INSTALL_PREFIX}" \ + --target-list="riscv64-linux-user" \ + --disable-system \ + --enable-linux-user \ + --disable-werror \ + --disable-docs \ + --disable-gtk \ + --disable-sdl \ + --disable-vnc \ + --disable-curl \ + --disable-capstone \ + --disable-kvm \ + --without-default-features \ + --enable-user + + echo " Building QEMU ($(nproc) jobs)..." 
+ if [[ -f "${src}/build/build.ninja" ]]; then + "${NINJA}" -C "${src}/build" -j"$(nproc)" + "${NINJA}" -C "${src}/build" install + else + make -j"$(nproc)" + make install + fi + + ok "qemu-riscv64 installed at ${qemu_bin}" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: Build IREE host (iree-opt, iree-compile, iree-run-module for x86) +# ───────────────────────────────────────────────────────────────────────────── +step4_iree_host() { + log "STEP 4: Build IREE (host — iree-opt, iree-compile, iree-run-module)" + + mkdir -p "${HOST_BUILD}" + + cmake -S "${IREE_SRC}" -B "${HOST_BUILD}" \ + -G Ninja \ + -DCMAKE_MAKE_PROGRAM="${NINJA}" \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_C_COMPILER="${CLANG}" \ + -DCMAKE_CXX_COMPILER="${CLANGXX}" \ + -DCMAKE_ASM_COMPILER="${CLANG}" \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_INSTALL_PREFIX="${HOST_INSTALL}" \ + -DIREE_ENABLE_ASSERTIONS=ON \ + -DIREE_ENABLE_SPLIT_DWARF=ON \ + -DIREE_ENABLE_LLD=ON \ + -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ + -DIREE_TARGET_BACKEND_LLVM_CPU=ON \ + -DIREE_HAL_DRIVER_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ + -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ + -DIREE_BUILD_PYTHON_BINDINGS=OFF \ + -DBENCHMARK_ENABLE_TESTING=OFF \ + -DHAVE_STD_REGEX=ON \ + -DHAVE_POSIX_REGEX=OFF + + echo " Building ($(nproc) jobs)..." + "${NINJA}" -C "${HOST_BUILD}" -j"$(nproc)" iree-opt iree-compile iree-run-module iree-tblgen + + echo " Installing host tools to ${HOST_INSTALL}..." 
+ "${NINJA}" -C "${HOST_BUILD}" install/fast + + ok "iree-opt: ${HOST_INSTALL}/bin/iree-opt" + ok "iree-compile: ${HOST_INSTALL}/bin/iree-compile" + ok "iree-run-module: ${HOST_INSTALL}/bin/iree-run-module" + ok "iree-tblgen: ${HOST_INSTALL}/bin/iree-tblgen" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 5: Build IREE RISC-V (iree-run-module cross-compiled for riscv64) +# ───────────────────────────────────────────────────────────────────────────── +step5_iree_riscv() { + log "STEP 5: Build IREE (RISC-V cross — iree-run-module for riscv64)" + + [[ -f "${HOST_INSTALL}/bin/iree-tblgen" ]] || \ + die "Host install not found at ${HOST_INSTALL}/bin — run step 4 first." + + mkdir -p "${RISCV_BUILD}" + + cmake -S "${IREE_SRC}" -B "${RISCV_BUILD}" \ + -G Ninja \ + -DCMAKE_MAKE_PROGRAM="${NINJA}" \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_TOOLCHAIN_FILE="${IREE_SRC}/build_tools/cmake/riscv.toolchain.cmake" \ + -DIREE_HOST_BIN_DIR="${HOST_INSTALL}/bin" \ + -DRISCV_TOOLCHAIN_ROOT="${RISCV_TOOLCHAIN}" \ + -DIREE_BUILD_COMPILER=OFF \ + -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ + -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ + -DIREE_BUILD_PYTHON_BINDINGS=OFF \ + -DBENCHMARK_ENABLE_TESTING=OFF \ + -DCMAKE_INSTALL_PREFIX="${RISCV_BUILD}/install" + + echo " Building ($(nproc) jobs)..." 
+ "${NINJA}" -C "${RISCV_BUILD}" -j"$(nproc)" iree-run-module + "${NINJA}" -C "${RISCV_BUILD}" install/fast + + ok "iree-run-module (riscv64): ${RISCV_BUILD}/install/bin/iree-run-module" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── +should_run 1 && step1_toolchain +should_run 2 && step2_sysroot +should_run 3 && step3_qemu +should_run 4 && step4_iree_host +should_run 5 && step5_iree_riscv + +echo "" +echo "════════════════════════════════════════════════════════════" +echo " Setup complete. Run the end-to-end workflows:" +echo "" +echo " bash build_tools/riscv/rvv_qemu_workflow_static.sh" +echo " bash build_tools/riscv/rvv_qemu_workflow_dynamic.sh" +echo "════════════════════════════════════════════════════════════" diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt index 2738c8d9fa99..d9d0fa598477 100644 --- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt @@ -37,6 +37,7 @@ iree_cc_library( "BindSymbolicShapes.cpp" "BitCastTensor.cpp" "ConvertTMTensorToLinalgExt.cpp" + "ConvertTorchToMIPS.cpp" "ConvertTorchUnstructuredToLinalgExt.cpp" "FuncConversion.cpp" "SetStrictSymbolicShapes.cpp" @@ -57,6 +58,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR + iree::compiler::Dialect::MIPS::IR iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR PUBLIC diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp new file mode 100644 index 000000000000..3c377e93e668 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp @@ -0,0 +1,298 @@ +// Copyright 2024 The IREE 
Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Converts torch matmul ops → mips.matmul. +// +// Patterns handled: +// ConvertAtenMmToMIPSMatmul — torch.aten.mm (f32 or i8 inputs) +// ConvertAtenIntMmToMIPSMatmul — torch.aten._int_mm (i8 × i8 → i32) +// +// Both patterns run inside the Torch input-conversion pipeline, BEFORE +// createConvertTorchToLinalgPass(), so they intercept the ops first. +// +// Since torch ops carry ValueTensorType (torch's tensor type), each pattern: +// 1. Casts operands to builtin RankedTensorType via ToBuiltinTensorOp. +// 2. Creates a zero-initialised init tensor (Destination Passing Style). +// 3. Emits mips.matmul on builtin tensors. +// 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp. +// +// The mips.matmul op is eliminated during One-Shot Bufferize: +// MIPSBufferizableOpInterface detects the LHS element type and calls either +// my_matmul_kernel (f32) or my_matmul_kernel_i8 (i8→i32). +// + +#include "compiler/plugins/input/Torch/InputConversion/Passes.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +namespace mlir::iree_compiler::TorchInput { + +#define GEN_PASS_DEF_CONVERTTORCHTOMIPSPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Helper: normalize signed/unsigned integer types to signless. 
+// +// arith.constant (and most MLIR arithmetic ops) require signless integers. +// Torch's dtype mapping produces signed types (e.g. si8, si32) which must be +// converted to their signless equivalents (i8, i32) before entering arith/ +// linalg/tensor dialects. +//===----------------------------------------------------------------------===// + +static Type toSignlessElemType(MLIRContext *ctx, Type ty) { + if (auto intTy = dyn_cast(ty)) + if (!intTy.isSignless()) + return IntegerType::get(ctx, intTy.getWidth()); + return ty; +} + +static RankedTensorType toSignlessTensorType(RankedTensorType ty) { + Type elem = toSignlessElemType(ty.getContext(), ty.getElementType()); + if (elem == ty.getElementType()) return ty; + return RankedTensorType::get(ty.getShape(), elem, ty.getEncoding()); +} + +//===----------------------------------------------------------------------===// +// Helper: create a zero-filled tensor of a given shape and element type. +// Accepts (M, N) as dynamic Value dimensions. +//===----------------------------------------------------------------------===// + +static Value createZeroTensor(PatternRewriter &rewriter, Location loc, + RankedTensorType ty, ValueRange dynSizes) { + // Use signless element types — arith.constant rejects signed integers. + RankedTensorType signlessTy = toSignlessTensorType(ty); + Value empty = tensor::EmptyOp::create(rewriter, loc, signlessTy, dynSizes); + Attribute zeroAttr = rewriter.getZeroAttr(signlessTy.getElementType()); + Value zero = arith::ConstantOp::create(rewriter, loc, cast(zeroAttr)); + return linalg::FillOp::create(rewriter, loc, zero, empty).result(); +} + +//===----------------------------------------------------------------------===// +// Helper: check whether a torch dtype is one we handle in mips.matmul. +// Returns true for f32 and si8/i8 inputs. 
+//===----------------------------------------------------------------------===// + +static bool isSupportedMmDtype(Type torchDtype) { + return torchDtype.isF32() || torchDtype.isSignedInteger(8) || + torchDtype.isInteger(8); +} + +//===----------------------------------------------------------------------===// +// Pattern: torch.aten.mm → mips.matmul +// +// Handles f32 × f32 → f32 (original path) +// and i8 × i8 → i8 (rare, but supported via same mips.matmul) +//===----------------------------------------------------------------------===// + +struct ConvertAtenMmToMIPSMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(torch::Torch::AtenMmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // ---------------------------------------------------------------- + // 1. Verify that we have supported tensor types. + // ---------------------------------------------------------------- + auto lhsTorchTy = + dyn_cast(op.getSelf().getType()); + auto rhsTorchTy = + dyn_cast(op.getMat2().getType()); + auto resultTorchTy = + dyn_cast(op.getType()); + + if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) + return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); + + if (!isSupportedMmDtype(lhsTorchTy.getDtype())) + return rewriter.notifyMatchFailure(op, "unsupported dtype (f32 or i8 only)"); + + // ---------------------------------------------------------------- + // 2. Cast operands from torch ValueTensorType → builtin RankedTensorType. 
+ // ---------------------------------------------------------------- + auto lhsBuiltinTy = dyn_cast_or_null( + lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = dyn_cast_or_null( + rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = dyn_cast_or_null( + resultTorchTy.toBuiltinTensor()); + + if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || + lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + + // Normalize signed integer element types to signless (arith requires it). + lhsBuiltinTy = toSignlessTensorType(lhsBuiltinTy); + rhsBuiltinTy = toSignlessTensorType(rhsBuiltinTy); + resultBuiltinTy = toSignlessTensorType(resultBuiltinTy); + + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, lhsBuiltinTy, op.getSelf()); + Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, rhsBuiltinTy, op.getMat2()); + + // ---------------------------------------------------------------- + // 3. Collect dynamic dimension values for the result tensor (M, N). + // ---------------------------------------------------------------- + SmallVector dynSizes; + if (resultBuiltinTy.isDynamicDim(0)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0)); + if (resultBuiltinTy.isDynamicDim(1)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1)); + + // ---------------------------------------------------------------- + // 4. Create a zero-initialised init tensor for DPS output. + // ---------------------------------------------------------------- + Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes); + + // ---------------------------------------------------------------- + // 5. Emit mips.matmul on builtin tensors. 
+ // ---------------------------------------------------------------- + Value result = + IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy}, + lhs, rhs, init) + .getResult(); + + // ---------------------------------------------------------------- + // 6. Cast result back to ValueTensorType so downstream torch passes can + // still operate on it until the type finalisation pass runs. + // ---------------------------------------------------------------- + Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, resultTorchTy, result); + + rewriter.replaceOp(op, torchResult); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pattern: torch.aten._int_mm → mips.matmul +// +// torch.aten._int_mm: i8 × i8 → i32 (integer matrix multiply). +// This is the primary op produced by INT8 quantization pipelines (e.g. +// torch.ao.quantization, torchao). +// +// mips.matmul carries i8 LHS/RHS and i32 output; MIPSBufferizableOpInterface +// detects the i8 LHS element type and emits func.call @my_matmul_kernel_i8. +//===----------------------------------------------------------------------===// + +struct ConvertAtenIntMmToMIPSMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(torch::Torch::Aten_IntMmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // ---------------------------------------------------------------- + // 1. Verify operand and result types. + // ---------------------------------------------------------------- + auto lhsTorchTy = + dyn_cast(op.getSelf().getType()); + auto rhsTorchTy = + dyn_cast(op.getMat2().getType()); + auto resultTorchTy = + dyn_cast(op.getType()); + + if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) + return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); + + // _int_mm expects i8 inputs and i32 output. 
+ if (!lhsTorchTy.getDtype().isSignedInteger(8) && + !lhsTorchTy.getDtype().isInteger(8)) + return rewriter.notifyMatchFailure(op, "expected i8 lhs"); + + // ---------------------------------------------------------------- + // 2. Cast to builtin tensor types (signless integers — arith requires it). + // ---------------------------------------------------------------- + auto lhsBuiltinTy = dyn_cast_or_null( + lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = dyn_cast_or_null( + rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = dyn_cast_or_null( + resultTorchTy.toBuiltinTensor()); + + if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || + lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + + lhsBuiltinTy = toSignlessTensorType(lhsBuiltinTy); + rhsBuiltinTy = toSignlessTensorType(rhsBuiltinTy); + resultBuiltinTy = toSignlessTensorType(resultBuiltinTy); + + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, lhsBuiltinTy, op.getSelf()); + Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, rhsBuiltinTy, op.getMat2()); + + // ---------------------------------------------------------------- + // 3. Dynamic dims for the i32 result tensor (M from lhs, N from rhs). + // ---------------------------------------------------------------- + SmallVector dynSizes; + if (resultBuiltinTy.isDynamicDim(0)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0)); + if (resultBuiltinTy.isDynamicDim(1)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1)); + + // ---------------------------------------------------------------- + // 4. Zero-initialised i32 init tensor. + // ---------------------------------------------------------------- + Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes); + + // ---------------------------------------------------------------- + // 5. 
Emit mips.matmul — LHS is i8, result is i32. + // MIPSBufferizableOpInterface dispatches to my_matmul_kernel_i8. + // ---------------------------------------------------------------- + Value result = + IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy}, + lhs, rhs, init) + .getResult(); + + // ---------------------------------------------------------------- + // 6. Cast result back to torch ValueTensorType. + // ---------------------------------------------------------------- + Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, resultTorchTy, result); + + rewriter.replaceOp(op, torchResult); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct ConvertTorchToMIPSPass + : impl::ConvertTorchToMIPSPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index ce0cb9c34f36..888d400db5ea 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -8,6 +8,8 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" 
@@ -63,6 +65,11 @@ void createTorchToIREEPipeline( pm.addNestedPass(torch::createConvertTorchToTensorPass()); pm.addNestedPass( TorchInput::createConvertTorchUnstructuredToLinalgExtPass()); + // MIPS: When enabled, intercept aten.mm before the standard torch->linalg + // pass and route it through mips.matmul -> func.call @my_matmul_kernel. + if (options.useMIPSMatmul) { + pm.addNestedPass(TorchInput::createConvertTorchToMIPSPass()); + } pm.addNestedPass(torch::createConvertTorchToLinalgPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(torch::createConvertTorchToSCFPass()); diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.h b/compiler/plugins/input/Torch/InputConversion/Passes.h index 23995b21a943..8eb22c3a6034 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/InputConversion/Passes.h @@ -43,6 +43,12 @@ struct TorchToIREELoweringPipelineOptions "program inputs. This buffer will be used for storing transient " "memory and must be provided by the user."), llvm::cl::init(false)}; + Option useMIPSMatmul{ + *this, "use-mips-matmul", + llvm::cl::desc("If enabled, lowers torch.aten.mm through the MIPS " + "custom dialect (mips.matmul) instead of the standard " + "torch->linalg path."), + llvm::cl::init(false)}; }; // Creates a pipeline that lowers from the torch backend contract to IREE. 
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index a868d4bb8354..e030ff6ba882 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -29,6 +29,17 @@ def ConvertTorchUnstructuredToLinalgExtPass : let summary = "Convert unstructured Torch ops to LinalgExt ops"; } +def ConvertTorchToMIPSPass : + InterfacePass<"torch-iree-to-mips-matmul", "mlir::FunctionOpInterface"> { + let summary = "Convert torch.aten.mm to mips.matmul"; + let description = [{ + Intercepts torch.aten.mm before the standard torch->linalg conversion and + replaces it with mips.matmul. The mips.matmul op is eliminated entirely + during One-Shot Bufferize: MIPSBufferizableOpInterface emits a direct + func.call @my_matmul_kernel with decomposed memref arguments. + }]; +} + def SetStrictSymbolicShapesPass : InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> { let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt index b1785a708878..3d451b68f13a 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt @@ -9,6 +9,7 @@ iree_lit_test_suite( "attention.mlir" "bind_symbolic_shapes.mlir" "bitcast_tensor.mlir" + "convert_torch_to_mips.mlir" "func_conversion.mlir" "func_conversion_invalid.mlir" "func_conversion_transients.mlir" diff --git a/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir new file mode 100644 index 000000000000..98d5407fefc5 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir @@ -0,0 +1,37 @@ +// RUN: 
iree-opt --split-input-file \ +// RUN: --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \ +// RUN: %s | FileCheck %s + +// ───────────────────────────────────────────────────────────────────────────── +// Static-shape: torch.aten.mm on f32 tensors → mips.matmul +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_static +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<4x8xf32> +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<8x4xf32> +// CHECK: mips.matmul {{.*}} : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32> +// CHECK-NOT: torch.aten.mm +func.func @mm_static(%A: !torch.vtensor<[4,8],f32>, + %B: !torch.vtensor<[8,4],f32>) + -> !torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ───────────────────────────────────────────────────────────────────────────── +// Non-f32 (i32) should be left untouched (pattern rejects non-f32 dtypes). +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_i32_unchanged +// CHECK-NOT: mips.matmul +// CHECK: torch.aten.mm +func.func @mm_i32_unchanged(%A: !torch.vtensor<[4,8],si32>, + %B: !torch.vtensor<[8,4],si32>) + -> !torch.vtensor<[4,4],si32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,4],si32> + -> !torch.vtensor<[4,4],si32> + return %0 : !torch.vtensor<[4,4],si32> +} diff --git a/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt new file mode 100644 index 000000000000..3da5a7a85912 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Recursively picks up IR/ and Transforms/ subdirectories. +iree_add_all_subdirs() diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt new file mode 100644 index 000000000000..632b1ee5b190 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +# ─── Tablegen ──────────────────────────────────────────────────────────────── +# Generate op declarations/definitions AND dialect declarations/definitions +# from a single MIPSOps.td source (same pattern as LinalgExt). + +iree_tablegen_library( + NAME + MIPSOpsIncGen + TD_FILE + "MIPSOps.td" + OUTS + --gen-op-decls MIPSOps.h.inc + --gen-op-defs MIPSOps.cpp.inc + --dialect=mips --gen-dialect-decls MIPSDialect.h.inc + --dialect=mips --gen-dialect-defs MIPSDialect.cpp.inc +) + +# ─── C++ library ───────────────────────────────────────────────────────────── + +iree_cc_library( + NAME + IR + HDRS + "MIPSDialect.h" + "MIPSOps.h" + "MIPSDialect.h.inc" + TEXTUAL_HDRS + "MIPSOps.h.inc" + "MIPSOps.cpp.inc" + SRCS + "MIPSDialect.cpp" + "MIPSDialect.cpp.inc" + "MIPSOps.cpp" + "MIPSBufferizableOpInterface.cpp" + DEPS + ::MIPSOpsIncGen + LLVMSupport + MLIRIR + MLIRSupport + MLIRFuncDialect + MLIRMemRefDialect + MLIRTensorDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + MLIRTensorUtils + MLIRTransforms + PUBLIC +) + +# ─── Documentation ─────────────────────────────────────────────────────────── + +iree_tablegen_doc( + NAME + MIPSDialectDocGen + 
CATEGORY "Dialects"
  TD_FILE
    "MIPSOps.td"
  OUTS
    --gen-dialect-doc -dialect=mips MIPSDialect.md
)
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td
new file mode 100644
index 000000000000..470c2d624549
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td
@@ -0,0 +1,59 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECT_MIPS_BASE
#define IREE_DIALECT_MIPS_BASE

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// MIPS dialect definition
//===----------------------------------------------------------------------===//

def MIPS_Dialect : Dialect {
  let name = "mips";
  let cppNamespace = "::mlir::iree_compiler::IREE::MIPS";
  let summary = "Semantic dialect layer for dispatching to highly-optimized MIPS kernels.";
  let description = [{
    The MIPS dialect provides a semantic abstraction layer that bridges
    high-level tensor operations and hand-tuned, target-specific kernel
    implementations (e.g. RVV-vectorized routines).

    Ops in this dialect live entirely in the tensor domain — no memref forms
    are ever produced. They are eliminated during One-Shot Bufferize:
    the BufferizableOpInterface implementation allocates output buffers and
    emits direct `func.call` instructions to the underlying C kernels.

    The semantic level allows the compiler to apply higher-level
    transformations before lowering, such as:
      - Fusing producer ops (e.g. transposes, type promotions) into a single
        kernel call to avoid intermediate allocations.
      - Tiling and packing decisions based on target vector width.
      - Selecting between kernel variants (e.g. RVV LMUL=m4 vs. m8) based
        on operand shapes or hardware configuration.

    Pipeline:
      torch.aten.mm → mips.matmul → func.call @my_matmul_kernel
  }];

  let dependentDialects = [
    "::mlir::func::FuncDialect",
    "::mlir::tensor::TensorDialect"
  ];

  // No custom attribute types → do not declare parseAttribute/printAttribute
  // overrides. The base Dialect class handles the fallback behavior.
  let useDefaultAttributePrinterParser = 0;
}

//===----------------------------------------------------------------------===//
// Base op class
//===----------------------------------------------------------------------===//

class MIPS_Op<string mnemonic, list<Trait> traits = []>
    : Op<MIPS_Dialect, mnemonic, traits>;

#endif // IREE_DIALECT_MIPS_BASE
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
new file mode 100644
index 000000000000..551192ab492b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
@@ -0,0 +1,258 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Implements BufferizableOpInterface for mips.matmul.
//
// mips.matmul is a tensor-only, Destination-Passing-Style (DPS) op. It is
// eliminated *entirely* during One-Shot Bufferize: bufferize() obtains memref
// buffers for all three operands, decomposes each 2-D memref into
// (base_ptr, offset, stride0, stride1) via memref.extract_strided_metadata,
// and emits a func.call to the appropriate kernel directly.
//
// The kernel is selected based on the LHS element type:
//   f32 → func.call @my_matmul_kernel    (f32 × f32 → f32)
//   i8  → func.call @my_matmul_kernel_i8 (i8 × i8 → i32)
//
// No memref form of mips.matmul ever exists in the IR.
+// +// Before bufferization (f32 example): +// %C = mips.matmul %A, %B, %init +// : tensor, tensor, tensor -> tensor +// +// After bufferization (produced inside bufferize()): +// %A_meta = memref.extract_strided_metadata %A_buf -> (base, off, s0, s1) +// %B_meta = memref.extract_strided_metadata %B_buf -> (base, off, s0, s1) +// %C_meta = memref.extract_strided_metadata %C_buf -> (base, off, s0, s1) +// %M = memref.dim %A_buf, 0 +// %N = memref.dim %B_buf, 1 +// %K = memref.dim %A_buf, 1 +// call @my_matmul_kernel(%A_base, %A_off, %A_s0, %A_s1, +// %B_base, %B_off, %B_s0, %B_s1, +// %C_base, %C_off, %C_s0, %C_s1, +// %M, %N, %K) +// -- tensor result replaced by %C_buf via replaceOpWithBufferizedValues -- +// +// The INT8 path is identical but uses memref / memref base pointers +// and calls @my_matmul_kernel_i8 instead. + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" + +using namespace mlir; +using namespace mlir::bufferization; + +// When true, matmul kernels are emitted as direct linker-resolved calls +// (hal.import.static) instead of dynamic HAL import table entries. +// Pass --iree-mips-static-embedding to iree-compile to enable. +// Mutually exclusive with --executable_plugin at runtime. +static llvm::cl::opt clMIPSStaticEmbedding( + "iree-mips-static-embedding", + llvm::cl::desc( + "Emit mips matmul kernels as direct linker-resolved calls " + "(hal.import.static) instead of dynamic HAL imports. " + "Requires the kernel .o to be appended by lld_wrapper at compile " + "time. 
Mutually exclusive with --executable_plugin at runtime."), + llvm::cl::init(false)); + +namespace mlir::iree_compiler::IREE::MIPS { +namespace { + +static constexpr StringLiteral kKernelF32 = "my_matmul_kernel"; +static constexpr StringLiteral kKernelI8 = "my_matmul_kernel_i8"; + +//===----------------------------------------------------------------------===// +// Helper: ensure func.func private @ exists at module scope. +// +// The declaration carries {llvm.bareptr = true} so the LLVM backend passes +// bare pointer arguments instead of MLIR memref descriptor structs, matching +// the C kernel ABI. +//===----------------------------------------------------------------------===// + +static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter, + Operation *moduleOp, + StringRef kernelName, + FunctionType fnType, + Location loc) { + if (auto existing = dyn_cast_if_present( + SymbolTable::lookupSymbolIn(moduleOp, kernelName))) + return existing; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front()); + auto fnDecl = func::FuncOp::create(rewriter, loc, kernelName, fnType); + SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private); + fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true)); + // If --iree-mips-static-embedding was passed, emit a direct linker call + // instead of a dynamic HAL import table entry. + if (clMIPSStaticEmbedding) + fnDecl->setAttr("hal.import.static", rewriter.getUnitAttr()); + return fnDecl; +} + +//===----------------------------------------------------------------------===// +// Helper: decompose a 2-D memref into (base_ptr, offset, stride0, stride1). +// +// Uses memref.extract_strided_metadata. The base_ptr is always a rank-0 +// memref with DEFAULT address space (memref), regardless of the +// source memref's address space. Any IREE-specific memory space (e.g. 
+// #hal.descriptor_type) is stripped via +// memref.memory_space_cast so that: +// +// 1. The function declaration uses plain memref, which is stable +// across all pipeline stages. +// 2. eraseHALDescriptorTypeFromMemRefPass (which runs after bufferization and +// does NOT update external function declarations) cannot introduce a +// type mismatch between the call operands and the declaration. +// +// Combined with the {llvm.bareptr = true} attribute on the callee, the +// rank-0 memref lowers to a bare pointer matching the C ABI. +//===----------------------------------------------------------------------===// + +static void decomposeMemref2D(RewriterBase &rewriter, Location loc, + Value memref2D, + SmallVectorImpl &callOperands, + SmallVectorImpl &callArgTypes) { + Type indexType = IndexType::get(rewriter.getContext()); + + auto meta = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref2D); + + // Strip any IREE-specific memory space from the base pointer so the + // function declaration stays in the default address space. + Value basePtr = meta.getBaseBuffer(); + auto basePtrMemrefTy = cast(basePtr.getType()); + MemRefType plainBasePtrTy = + MemRefType::get(/*shape=*/{}, basePtrMemrefTy.getElementType()); + if (basePtrMemrefTy != plainBasePtrTy) { + basePtr = memref::MemorySpaceCastOp::create(rewriter, loc, plainBasePtrTy, + basePtr); + } + + callOperands.push_back(basePtr); + callArgTypes.push_back(plainBasePtrTy); + + callOperands.push_back(meta.getOffset()); + callArgTypes.push_back(indexType); + + for (Value stride : meta.getStrides()) { + callOperands.push_back(stride); + callArgTypes.push_back(indexType); + } +} + +//===----------------------------------------------------------------------===// +// External model — BufferizableOpInterface for mips.matmul. +// +// Inherits from DstBufferizableOpInterfaceExternalModel which automatically +// handles the DPS aliasing (init ↔ result) and write detection for the init +// operand. 
We override bufferizesToMemoryRead to mark lhs and rhs as read, +// and provide a custom bufferize() that selects the right kernel based on the +// LHS element type. +//===----------------------------------------------------------------------===// + +struct MIPSMatmulBufferizableOpInterface + : public DstBufferizableOpInterfaceExternalModel< + MIPSMatmulBufferizableOpInterface, MatmulOp> { + + // All three operands are read by the kernel. + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto matmulOp = cast(op); + Location loc = matmulOp.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(matmulOp); + + // Obtain memref buffers for all three tensor operands. + FailureOr lhsBuf = + getBuffer(rewriter, matmulOp.getLhs(), options, state); + if (failed(lhsBuf)) + return failure(); + FailureOr rhsBuf = + getBuffer(rewriter, matmulOp.getRhs(), options, state); + if (failed(rhsBuf)) + return failure(); + // init aliases with result — one-shot bufferize allocates the output buffer + // (via bufferization.alloc_tensor or in-place analysis) and gives it here. + FailureOr initBuf = + getBuffer(rewriter, matmulOp.getInit(), options, state); + if (failed(initBuf)) + return failure(); + + // Select the kernel based on LHS element type. + Type lhsElemTy = cast(lhsBuf->getType()).getElementType(); + StringRef kernelName; + if (lhsElemTy.isF32()) { + kernelName = kKernelF32; + } else if (lhsElemTy.isInteger(8)) { + kernelName = kKernelI8; + } else { + return matmulOp.emitOpError( + "MIPSBufferizableOpInterface: unsupported LHS element type '") + << lhsElemTy + << "'; supported types are f32 and i8"; + } + + // Build the flattened argument list for the kernel call. 
+ // For each 2-D memref: (base_ptr, offset, stride0, stride1) + // Then: M, N, K as index scalars. + SmallVector callOperands; + SmallVector callArgTypes; + + decomposeMemref2D(rewriter, loc, *lhsBuf, callOperands, callArgTypes); + decomposeMemref2D(rewriter, loc, *rhsBuf, callOperands, callArgTypes); + decomposeMemref2D(rewriter, loc, *initBuf, callOperands, callArgTypes); + + Type indexType = IndexType::get(ctx); + Value M = memref::DimOp::create(rewriter, loc, *lhsBuf, 0); + Value N = memref::DimOp::create(rewriter, loc, *rhsBuf, 1); + Value K = memref::DimOp::create(rewriter, loc, *lhsBuf, 1); + callOperands.append({M, N, K}); + callArgTypes.append(3, indexType); + + // Declare the kernel function in the enclosing module (idempotent). + Operation *moduleOp = SymbolTable::getNearestSymbolTable(matmulOp); + FunctionType fnType = rewriter.getFunctionType(callArgTypes, TypeRange{}); + ensureKernelDeclaration(rewriter, moduleOp, kernelName, fnType, loc); + + // Emit the call — the kernel writes into *initBuf in place. + func::CallOp::create(rewriter, loc, kernelName, TypeRange{}, callOperands); + + // Replace the tensor result with the init buffer (DPS aliasing). 
+ replaceOpWithBufferizedValues(rewriter, op, *initBuf); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Public registration entry point +//===----------------------------------------------------------------------===// + +void registerMIPSBufferizableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, MIPSDialect * /*dialect*/) { + MatmulOp::attachInterface(*ctx); + }); +} + +} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp new file mode 100644 index 000000000000..1f46719bd7e5 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp @@ -0,0 +1,55 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" + +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::MIPS; + +//===----------------------------------------------------------------------===// +// Inliner interface — allow MIPS ops to be inlined unconditionally. 
+//===----------------------------------------------------------------------===// + +namespace { +struct MIPSInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Dialect initialize +//===----------------------------------------------------------------------===// + +void MIPSDialect::initialize() { + addInterfaces(); + +#define GET_OP_LIST + addOperations< +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc" + >(); +} + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h new file mode 100644 index 000000000000..aa99d0a63005 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h @@ -0,0 +1,28 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ +#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +// clang-format off +// MIPSDialect.h.inc is generated from MIPSOps.td via: +// --dialect=mips --gen-dialect-decls MIPSDialect.h.inc +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h.inc" // IWYU pragma: keep +// clang-format on + +namespace mlir::iree_compiler::IREE::MIPS { + +// Register external BufferizableOpInterface models for MIPS ops. +// Call this from registerIreeDialects() before bufferization runs. +void registerMIPSBufferizableOpInterfaceExternalModels( + mlir::DialectRegistry ®istry); + +} // namespace mlir::iree_compiler::IREE::MIPS + +#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp new file mode 100644 index 000000000000..a79f629f959e --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp @@ -0,0 +1,111 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::MIPS; + +//===----------------------------------------------------------------------===// +// MatmulOp — ReifyRankedShapedTypeOpInterface +// +// Returns the output shape [M, N] so IREE's dispatch formation can compute +// the workload when wrapping mips.matmul in a flow.dispatch.workgroups region. 
+//===----------------------------------------------------------------------===// + +LogicalResult MatmulOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + // Result is always tensor. M from lhs dim 0, N from rhs dim 1. + reifiedReturnShapes.push_back({tensor::getMixedSize(b, getLoc(), getLhs(), 0), + tensor::getMixedSize(b, getLoc(), getRhs(), 1)}); + return success(); +} + +//===----------------------------------------------------------------------===// +// MatmulOp — MemoryEffectsOpInterface +// +// In the tensor domain, ops are nominally pure (tensors are values, not memory). +// However mips.matmul uses DPS — the init operand logically "carries" the +// result. We declare read on lhs/rhs and read+write on init so that alias +// analyses outside of bufferization correctly treat init as modified. +//===----------------------------------------------------------------------===// + +void MatmulOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getInitMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getInitMutable(), + SideEffects::DefaultResource::get()); +} + +//===----------------------------------------------------------------------===// +// MatmulOp — Verifier +//===----------------------------------------------------------------------===// + +LogicalResult MatmulOp::verify() { + auto shape = [](Value v) { + return cast(v.getType()).getShape(); + }; + auto elemTy = [](Value v) { + return cast(v.getType()).getElementType(); + }; + + // All operands must be 2-D tensors. 
+ for (Value v : {getLhs(), getRhs(), getInit()}) { + if (cast(v.getType()).getRank() != 2) + return emitOpError("all operands must be 2-D ranked tensors"); + } + + // Dimension compatibility: lhs[M x K], rhs[K x N], init[M x N]. + // Only validate static dimensions; dynamic dims are checked at runtime. + auto compat = [](int64_t a, int64_t b) { + return ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b; + }; + if (!compat(shape(getLhs())[0], shape(getInit())[0])) + return emitOpError("lhs dim 0 (M) must match init dim 0 (M)"); + if (!compat(shape(getLhs())[1], shape(getRhs())[0])) + return emitOpError("lhs dim 1 (K) must match rhs dim 0 (K)"); + if (!compat(shape(getRhs())[1], shape(getInit())[1])) + return emitOpError("rhs dim 1 (N) must match init dim 1 (N)"); + + // LHS and RHS element types must match. + if (elemTy(getLhs()) != elemTy(getRhs())) + return emitOpError("lhs and rhs element types must match"); + + // Supported element type combinations: + // f32 × f32 → f32 (standard float matmul) + // i8 × i8 → i32 (INT8 widening matmul) + Type lhsElem = elemTy(getLhs()); + Type outElem = elemTy(getInit()); + bool valid = (lhsElem == outElem) || + (lhsElem.isInteger(8) && outElem.isInteger(32)); + if (!valid) + return emitOpError( + "unsupported element type combination: lhs=") + << lhsElem << ", output=" << outElem + << "; supported: f32×f32→f32, i8×i8→i32"; + + // Result type must match init type. 
+ if (getResult().getType() != getInit().getType()) + return emitOpError("result type must match init type"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen generated op definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h new file mode 100644 index 000000000000..dc2881aca2e3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h @@ -0,0 +1,25 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_ +#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_ + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +// clang-format off + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h.inc" // IWYU pragma: export + +// clang-format on + +#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td new file mode 100644 index 000000000000..e87073fe3b99 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td @@ -0,0 +1,71 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_MIPS_OPS +#define IREE_DIALECT_MIPS_OPS + +include "iree/compiler/Dialect/MIPS/IR/MIPSBase.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// mips.matmul — tensor-only semantic op +// +// This op exists exclusively in the tensor domain. It is eliminated during +// One-Shot Bufferize: the BufferizableOpInterface implementation allocates +// the output memref and emits func.call @mips_matmul(...) directly, so no +// memref form of this op ever exists in the IR. +//===----------------------------------------------------------------------===// + +def MIPS_MatmulOp : MIPS_Op<"matmul", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DestinationStyleOpInterface +]> { + let summary = "MIPS matrix multiplication (tensor domain): result = lhs * rhs"; + let description = [{ + Computes a 2-D matrix multiplication in the tensor domain using + destination-passing style (DPS). The caller provides an `init` tensor that + One-Shot Bufferize uses to determine the output buffer — typically + `bufferization.alloc_tensor` for a fresh allocation. + + The semantic is: `result[m, n] = sum_k(lhs[m, k] * rhs[k, n])`. + + This op is created by the Torch -> MIPS conversion pass from `torch.aten.mm` + and is eliminated entirely during bufferization: the BufferizableOpInterface + implementation emits `func.call @mips_matmul` directly with the bufferized + memref operands. No memref-form `mips.matmul` is ever produced. 
+ + Example: + ```mlir + %result = mips.matmul %A, %B, %init + : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32> + ``` + }]; + + let arguments = (ins + AnyRankedTensor:$lhs, // [M x K] + AnyRankedTensor:$rhs, // [K x N] + AnyRankedTensor:$init // [M x N] — DPS destination (typically alloc_tensor) + ); + + let results = (outs AnyRankedTensor:$result); // [M x N] + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $init attr-dict `:` + type($lhs) `,` type($rhs) `,` type($init) `->` type($result) + }]; + + let extraClassDeclaration = [{ + // DestinationStyleOpInterface: init is the DPS output operand. + MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + }]; + + let hasVerifier = 1; +} + +#endif // IREE_DIALECT_MIPS_OPS diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt index 4b124334e1ef..01cc625df2af 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt @@ -89,6 +89,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::MIPS::IR iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::TensorExt::Transforms diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 6ba8a4bd8db2..f6a5467aee5c 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include 
"iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h" @@ -369,6 +370,9 @@ static bool isRootLikeOp(Operation *op) { return !isa(op); } + // MIPS: mips.matmul is a dispatch root (lowered to a custom C kernel call). + if (isa(op)) + return true; return isa(op); } diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index c47ae6cb4368..78f1d3768f68 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h @@ -23,6 +23,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -50,6 +51,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { IREE::HAL::Loader::HALLoaderDialect, IREE::IO::Parameters::IOParametersDialect, IREE::LinalgExt::IREELinalgExtDialect, + IREE::MIPS::MIPSDialect, IREE::PCF::PCFDialect, IREE::Encoding::IREEEncodingDialect, IREE::Stream::StreamDialect, @@ -65,6 +67,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { registerCodegenInterfaces(registry); registerGlobalOptimizationInterfaces(registry); registerUKernelBufferizationInterface(registry); + IREE::MIPS::registerMIPSBufferizableOpInterfaceExternalModels(registry); // Register transform dialect extensions. 
registerTransformDialectPreprocessingExtension(registry); diff --git a/runtime/src/iree/builtins/mips/CMakeLists.txt b/runtime/src/iree/builtins/mips/CMakeLists.txt new file mode 100644 index 000000000000..ac1fbaea22b4 --- /dev/null +++ b/runtime/src/iree/builtins/mips/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Shared library containing the MIPS custom matmul kernels. +# +# matmul_kernel.c — RVV f32 (or scalar fallback) compute kernel; no IREE headers. +# matmul_kernel_i8.c — RVV INT8 widening compute kernel; no IREE headers. +# matmul_plugin.c — IREE HAL executable plugin interface for both kernels. +# +# The plugin is loaded at runtime via: +# iree-run-module --executable_plugin=libmips_matmul.so ... + +add_library(mips_matmul SHARED + matmul_kernel.c + matmul_kernel_i8.c + matmul_plugin.c +) + +target_include_directories(mips_matmul + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + PRIVATE + # For iree/hal/local/executable_plugin.h (standalone C header, no deps) + ${PROJECT_SOURCE_DIR}/runtime/src +) + +set_target_properties(mips_matmul PROPERTIES + C_VISIBILITY_PRESET default + POSITION_INDEPENDENT_CODE ON +) diff --git a/runtime/src/iree/builtins/mips/README.md b/runtime/src/iree/builtins/mips/README.md new file mode 100644 index 000000000000..fff6c4004a57 --- /dev/null +++ b/runtime/src/iree/builtins/mips/README.md @@ -0,0 +1,138 @@ +# MIPS Kernel Library + +This directory contains the hand-tuned kernel implementations for the +`mips` IREE dialect — a semantic dispatch layer that maps high-level tensor +operations to target-specific, optimized C kernels (currently RVV-vectorized +RISC-V matmul). 
+ +## Source Files + +| File | Purpose | +|------|---------| +| `matmul_kernel.h` | Public API declaration for `my_matmul_kernel` | +| `matmul_kernel.c` | RVV-vectorized (or scalar fallback) compute kernel; **no IREE headers** | +| `matmul_plugin.c` | IREE HAL executable plugin interface (wraps `matmul_kernel.c`) | +| `rvv_standalone_test.c` | Standalone QEMU smoke-test (no IREE dependency) | + +### Design Principle + +`matmul_kernel.c` is intentionally free of IREE headers, making it usable +for three different build targets without modification: + +``` +matmul_kernel.c ──┬── (.o) baked into dispatch ELF at compile time (static) + ├── (.so) IREE plugin loaded via --executable_plugin (dynamic) + └── linked with rvv_standalone_test.c (QEMU unit test) +``` + +## Kernel ABI + +Matches the `func.call` emitted by `MIPSBufferizableOpInterface` after +decomposing 2-D memrefs with `memref.extract_strided_metadata` +(`{llvm.bareptr = true}`, so `memref` → `float*`): + +```c +void my_matmul_kernel( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] + int64_t M, int64_t N, int64_t K); +``` + +Each 2-D matrix is passed as `(base_ptr, offset, row_stride, col_stride)`, +supporting arbitrary memory layouts (row-major, column-major, non-contiguous). + +## Building + +Assumes `IREE_SRC` = path to this repo, `CLANG` = `~/miniforge3/bin/clang`. 
+ +### Standalone test binary (no IREE, no libc) + +```bash +# RISC-V RVV (run under QEMU) +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -static -nostdlib -ffreestanding -nostdinc \ + -isystem ~/miniforge3/lib/clang/18/include \ + matmul_kernel.c rvv_standalone_test.c -o rvv_test +qemu-riscv64 -cpu rv64,v=true,vlen=512,elen=64,vext_spec=v1.0 ./rvv_test + +# x86 host (scalar fallback, with libc) +clang matmul_kernel.c rvv_standalone_test.c -O2 -o rvv_test_host && ./rvv_test_host +``` + +Expected output: +``` +=== rvv_matmul standalone test [RVV] +[1] A * I = A (4x4 row-major) +[2] 2x3 * 3x2 = 2x2 +[3] col-major strides +[4] non-zero base offset +PASSED (20 passed, 0 failed) +``` + +### Static object (.o) — baked into dispatch ELF + +```bash +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -c -nostdinc -isystem ~/miniforge3/lib/clang/18/include \ + -I "${IREE_SRC}/runtime/src" \ + matmul_kernel.c -o matmul_kernel_riscv.o +``` + +### Dynamic plugin (.so) — loaded at runtime + +```bash +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -fPIC -shared -nostdinc -nostdlib \ + -isystem ~/miniforge3/lib/clang/18/include \ + -fuse-ld=~/miniforge3/bin/ld.lld \ + -I "${IREE_SRC}/runtime/src" \ + matmul_kernel.c matmul_plugin.c -o librvv_matmul.so +``` + +## Integration with IREE + +``` +torch.aten.mm + ─[ConvertTorchToMIPSPass]──► mips.matmul (flow IR, tensor domain) + ─[One-Shot Bufferize]──────► func.call @my_matmul_kernel (buffers decomposed) + ─[iree-compile LLVMCPU]────► dispatch ELF inside .vmfb +``` + +`MIPSBufferizableOpInterface` handles bufferization by decomposing each 2-D +memref into `(base_ptr, offset, stride0, stride1)` via +`memref.extract_strided_metadata` and emitting the `func.call` directly. +No memref form of `mips.matmul` is ever produced in the IR. + +### Static Embedding (`--iree-mips-static-embedding`) + +Pass `--iree-mips-static-embedding` to `iree-compile`. 
The bufferizer tags +`my_matmul_kernel` with `{hal.import.static}`, causing the LLVMCPU backend +to emit a direct linker-resolved call. A custom `lld_wrapper.sh` appends +`matmul_kernel_riscv.o` to every dispatch link at compile time. + +- No `--executable_plugin` at runtime — kernel is inside the `.vmfb`. +- Requires `-Bsymbolic` in the lld invocation. Without it, lld generates + `R_RISCV_JUMP_SLOT` in `.rela.plt`; IREE's embedded ELF loader ignores + `.rela.plt` (only processes `.rela.dyn`), causing a segfault on first call. + +### Dynamic Loading (`--executable_plugin`) + +Without the flag, `my_matmul_kernel` is a HAL import table entry resolved at +runtime from the plugin `.so` via `iree_hal_executable_plugin_query`. + +```bash +iree-run-module --module=matmul.vmfb \ + --executable_plugin=librvv_matmul.so \ + --function=matmul_4x4 ... +``` + +## End-to-End Workflow Scripts + +See [`build_tools/riscv/`](../../../../../build_tools/riscv/) in the repo root: + +| Script | Description | +|--------|-------------| +| `setup_qemu_workflow.sh` | One-time setup: toolchain, QEMU, IREE host + RISC-V builds | +| `rvv_qemu_workflow_static.sh` | Static-embedding pipeline + QEMU run | +| `rvv_qemu_workflow_dynamic.sh` | Dynamic-plugin pipeline + QEMU run | diff --git a/runtime/src/iree/builtins/mips/matmul_kernel.c b/runtime/src/iree/builtins/mips/matmul_kernel.c new file mode 100644 index 000000000000..02eccb06ccbf --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_kernel.c @@ -0,0 +1,110 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// RVV (RISC-V Vector 1.0) f32 matmul kernel. +// +// This file is intentionally free of IREE headers so it can be: +// - Compiled to a .o and baked into the dispatch ELF (static embedding). 
+// - Linked into the IREE plugin .so alongside matmul_plugin.c. +// - Compiled standalone for unit tests (rvv_standalone_test.c). +// +// Vectorization strategy (RVV LMUL=m4): +// Outer loops: m (rows of A) and k (contraction axis). +// Inner loop : n (columns of B), vectorized with vsetvl_e32m4. +// Each vl-wide strip of C[m, n:n+vl] is accumulated across k before storing. +// Arbitrary strides handled via conditional vlse/vsse vs vle/vse. + +#include "matmul_kernel.h" + +#include + +#ifdef __riscv_vector +#include +#endif + +//===----------------------------------------------------------------------===// +// Internal compute kernel +//===----------------------------------------------------------------------===// + +#ifdef __riscv_vector + +static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + + const int a_unit_col = (A_s1 == 1); + const int b_unit_col = (B_s1 == 1); + const int c_unit_col = (C_s1 == 1); + + for (int64_t m = 0; m < M; ++m) { + int64_t n = 0; + while (n < N) { + size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n)); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t k = 0; k < K; ++k) { + float a_val = a_unit_col ? A[m * A_s0 + k] + : A[m * A_s0 + k * A_s1]; + vfloat32m4_t b_vec = + b_unit_col + ? 
__riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl) + : __riscv_vlse32_v_f32m4( + &B[k * B_s0 + n * B_s1], + (ptrdiff_t)(B_s1 * (int64_t)sizeof(float)), vl); + acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl); + } + + if (c_unit_col) + __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl); + else + __riscv_vsse32_v_f32m4( + &C[m * C_s0 + n * C_s1], + (ptrdiff_t)(C_s1 * (int64_t)sizeof(float)), acc, vl); + n += (int64_t)vl; + } + } +} + +#else // scalar fallback + +static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + for (int64_t m = 0; m < M; ++m) + for (int64_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (int64_t k = 0; k < K; ++k) + acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } +} + +#endif // __riscv_vector + +//===----------------------------------------------------------------------===// +// Public entry point +//===----------------------------------------------------------------------===// + +void my_matmul_kernel( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + rvv_matmul_core(A, A_off, A_s0, A_s1, + B, B_off, B_s0, B_s1, + C, C_off, C_s0, C_s1, + M, N, K); +} diff --git a/runtime/src/iree/builtins/mips/matmul_kernel.h b/runtime/src/iree/builtins/mips/matmul_kernel.h new file mode 100644 index 000000000000..11cf87ff1726 --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_kernel.h @@ -0,0 +1,60 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Public API for the MIPS matmul kernel. +// +// ABI contract (must match MIPSBufferizableOpInterface.cpp / decomposeMemref2D): +// Each 2-D memref is decomposed into (base_ptr, offset, stride0, stride1). +// {llvm.bareptr = true} means memref → float* in C. +// index → int64_t on RV64. +// +// C signatures (15 args each): +// void my_matmul_kernel( +// float* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] +// float* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] +// float* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] +// int64_t M, int64_t N, int64_t K +// ); +// void my_matmul_kernel_i8( +// int8_t* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] i8 +// int8_t* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] i8 +// int32_t* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] i32 +// int64_t M, int64_t N, int64_t K +// ); + +#ifndef IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_ +#define IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// 2-D f32 matmul: C = A * B, Destination-Passing Style. +// Supports arbitrary row/col strides and non-zero base offsets. +// RVV-vectorized on RISC-V targets; scalar fallback elsewhere. +void my_matmul_kernel( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K); + +// 2-D INT8 matmul: C[M×N] (i32) = A[M×K] (i8) * B[K×N] (i8). +// C is zero-initialized by the caller before the call. +// Supports arbitrary row/col strides and non-zero base offsets. +// RVV-vectorized (widening i8→i16→i32) on RISC-V; scalar fallback elsewhere. 
+void my_matmul_kernel_i8( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K); + +#ifdef __cplusplus +} +#endif + +#endif // IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_ diff --git a/runtime/src/iree/builtins/mips/matmul_kernel_i8.c b/runtime/src/iree/builtins/mips/matmul_kernel_i8.c new file mode 100644 index 000000000000..0ed667401dd7 --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_kernel_i8.c @@ -0,0 +1,146 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// RVV (RISC-V Vector 1.0) INT8 matmul kernel: C[i32] = A[i8] * B[i8]. +// +// This file is intentionally free of IREE headers so it can be: +// - Compiled to a .o and baked into the dispatch ELF (static embedding). +// - Linked into the IREE plugin .so alongside matmul_plugin.c. +// - Compiled standalone for unit tests. +// +// Vectorization strategy (RVV widening multiply): +// Outer loops: m (rows of A) and k (contraction axis). +// Inner loop : n (columns of B), vectorized. +// +// LMUL selection to match widening chain: +// i8 LMUL=m2 → VLMAX = (VLEN/8) * 2 = VLEN/4 +// i16 LMUL=m4 → VLMAX = (VLEN/16) * 4 = VLEN/4 (after first widening) +// i32 LMUL=m8 → VLMAX = (VLEN/32) * 8 = VLEN/4 (accumulator) +// All three LMULs yield the same VLMAX, so the same application vl is valid +// for all three element widths — no secondary vsetvl is needed per intrinsic. 
+// +// Per n-strip: +// acc[vl] (i32m8) = 0 +// for k in 0..K: +// a_val (i8 scalar) = A[m, k] +// b_i8 (i8m2 vec) = B[k, n:n+vl] +// b_i16 (i16m4 vec) = sign_extend(b_i8) +// acc += widen_macc(a_val, b_i16) // i16 * i16 → i32 +// C[m, n:n+vl] = acc +// +// The widening chain avoids intermediate overflow: +// i8 × i8 → i16 intermediate → accumulated into i32. + +#include "matmul_kernel.h" + +#include +#include + +#ifdef __riscv_vector +#include +#endif + +//===----------------------------------------------------------------------===// +// Internal compute kernel (RVV path) +//===----------------------------------------------------------------------===// + +#ifdef __riscv_vector + +static void rvv_matmul_i8_core( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + + const int a_unit_col = (A_s1 == 1); + const int b_unit_col = (B_s1 == 1); + const int c_unit_col = (C_s1 == 1); + + for (int64_t m = 0; m < M; ++m) { + int64_t n = 0; + while (n < N) { + // vl: element count for this n-strip. + // __riscv_vsetvl_e32m8 caps vl at VLMAX_i32m8 = VLEN/4, which equals + // VLMAX_i8m2 and VLMAX_i16m4 — all intrinsics below share this vl. + size_t vl = __riscv_vsetvl_e32m8((size_t)(N - n)); + + // Initialize i32 accumulator to zero. + vint32m8_t acc = __riscv_vmv_v_x_i32m8(0, vl); + + for (int64_t k = 0; k < K; ++k) { + // Scalar load from A[m, k]. + int8_t a_val = a_unit_col ? A[m * A_s0 + k] + : A[m * A_s0 + k * A_s1]; + + // Vector load B[k, n:n+vl] as i8. + vint8m2_t b_i8 = + b_unit_col + ? __riscv_vle8_v_i8m2((const int8_t *)&B[k * B_s0 + n], vl) + : __riscv_vlse8_v_i8m2( + (const int8_t *)&B[k * B_s0 + n * B_s1], + (ptrdiff_t)(B_s1 * (int64_t)sizeof(int8_t)), vl); + + // Sign-extend i8m2 → i16m4 (same element count, double the width). 
+ vint16m4_t b_i16 = __riscv_vsext_vf2_i16m4(b_i8, vl); + + // Widening signed multiply-accumulate: + // acc[i] += sign_ext_32(a_val) * sign_ext_32(b_i16[i]) + // vwmacc.vx vd[i32m8], rs1[i16], vs2[i16m4] + acc = __riscv_vwmacc_vx_i32m8(acc, (int16_t)a_val, b_i16, vl); + } + + // Store accumulator to C[m, n:n+vl]. + if (c_unit_col) + __riscv_vse32_v_i32m8((int32_t *)&C[m * C_s0 + n], acc, vl); + else + __riscv_vsse32_v_i32m8( + (int32_t *)&C[m * C_s0 + n * C_s1], + (ptrdiff_t)(C_s1 * (int64_t)sizeof(int32_t)), acc, vl); + + n += (int64_t)vl; + } + } +} + +#else // scalar fallback + +static void rvv_matmul_i8_core( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + for (int64_t m = 0; m < M; ++m) + for (int64_t n = 0; n < N; ++n) { + int32_t acc = 0; + for (int64_t k = 0; k < K; ++k) + acc += (int32_t)A[m * A_s0 + k * A_s1] * + (int32_t)B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } +} + +#endif // __riscv_vector + +//===----------------------------------------------------------------------===// +// Public entry point +//===----------------------------------------------------------------------===// + +void my_matmul_kernel_i8( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + rvv_matmul_i8_core(A, A_off, A_s0, A_s1, + B, B_off, B_s0, B_s1, + C, C_off, C_s0, C_s1, + M, N, K); +} diff --git a/runtime/src/iree/builtins/mips/matmul_plugin.c b/runtime/src/iree/builtins/mips/matmul_plugin.c new file mode 100644 index 000000000000..5d5765351bf5 --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_plugin.c @@ -0,0 +1,146 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under 
the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// IREE Executable Plugin interface for the MIPS matmul kernels. +// +// Registers both the f32 and INT8 kernels: +// my_matmul_kernel — f32 matmul +// my_matmul_kernel_i8 — INT8 (i8 inputs, i32 accumulator) matmul +// +// Build as a shared library alongside matmul_kernel.c and matmul_kernel_i8.c: +// clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ +// -O2 -fPIC -shared -nostdinc ... \ +// matmul_kernel.c matmul_kernel_i8.c matmul_plugin.c -o librvv_matmul.so + +#include "matmul_kernel.h" + +#include "iree/hal/local/executable_plugin.h" + +//===----------------------------------------------------------------------===// +// Import wrapper +//===----------------------------------------------------------------------===// +// The HAL plugin dispatch table expects functions with the signature: +// int fn(void *params_ptr, void *context, void *reserved) +// where params_ptr points to a packed struct matching the func.call ABI +// emitted by MIPSBufferizableOpInterface::bufferize() → decomposeMemref2D(). 
+ +// ── f32 kernel args ────────────────────────────────────────────────────────── +typedef struct { + float *A; + int64_t A_off, A_s0, A_s1; + float *B; + int64_t B_off, B_s0, B_s1; + float *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} matmul_kernel_args_t; + +static int matmul_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const matmul_kernel_args_t *a = (const matmul_kernel_args_t *)params_ptr; + my_matmul_kernel(a->A, a->A_off, a->A_s0, a->A_s1, + a->B, a->B_off, a->B_s0, a->B_s1, + a->C, a->C_off, a->C_s0, a->C_s1, + a->M, a->N, a->K); + return 0; +} + +// ── INT8 kernel args ────────────────────────────────────────────────────────── +typedef struct { + int8_t *A; + int64_t A_off, A_s0, A_s1; + int8_t *B; + int64_t B_off, B_s0, B_s1; + int32_t *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} matmul_i8_kernel_args_t; + +static int matmul_i8_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const matmul_i8_kernel_args_t *a = (const matmul_i8_kernel_args_t *)params_ptr; + my_matmul_kernel_i8(a->A, a->A_off, a->A_s0, a->A_s1, + a->B, a->B_off, a->B_s0, a->B_s1, + a->C, a->C_off, a->C_s0, a->C_s1, + a->M, a->N, a->K); + return 0; +} + +//===----------------------------------------------------------------------===// +// Plugin lifecycle +//===----------------------------------------------------------------------===// + +static iree_hal_executable_plugin_status_t plugin_load( + const iree_hal_executable_plugin_environment_v0_t *environment, + size_t param_count, + const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { + (void)environment; + (void)param_count; + (void)params; + *out_self = NULL; + return iree_hal_executable_plugin_ok_status(); +} + +static void plugin_unload(void *self) { (void)self; } + +static iree_hal_executable_plugin_status_t plugin_resolve( + void *self, const 
iree_hal_executable_plugin_resolve_params_v0_t *params, + iree_hal_executable_plugin_resolution_t *out_resolution) { + (void)self; + *out_resolution = 0; + bool any_required_not_found = false; + + for (size_t i = 0; i < params->count; ++i) { + if (params->out_fn_ptrs[i]) continue; + const char *name = params->symbol_names[i]; + bool optional = iree_hal_executable_plugin_import_is_optional(name); + if (optional) ++name; + + if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { + params->out_fn_ptrs[i] = matmul_kernel_import; + params->out_fn_contexts[i] = NULL; + } else if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel_i8") == 0) { + params->out_fn_ptrs[i] = matmul_i8_kernel_import; + params->out_fn_contexts[i] = NULL; + } else { + if (!optional) any_required_not_found = true; + } + } + + return any_required_not_found + ? iree_hal_executable_plugin_status_from_code( + IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) + : iree_hal_executable_plugin_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Plugin query entry point +//===----------------------------------------------------------------------===// + +IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** +iree_hal_executable_plugin_query( + iree_hal_executable_plugin_version_t max_version, void *reserved) { + static const iree_hal_executable_plugin_header_t header = { + .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, + .name = "mips_matmul", + .description = "RISC-V RVV 1.0 matmul kernel plugin (f32 + INT8)", + .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, + .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, + }; + static const iree_hal_executable_plugin_v0_t plugin = { + .header = &header, + .load = plugin_load, + .unload = plugin_unload, + .resolve = plugin_resolve, + }; + return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST + ? 
(const iree_hal_executable_plugin_header_t **)&plugin + : NULL; +} diff --git a/runtime/src/iree/builtins/mips/rvv_standalone_test.c b/runtime/src/iree/builtins/mips/rvv_standalone_test.c new file mode 100644 index 000000000000..1b68de181991 --- /dev/null +++ b/runtime/src/iree/builtins/mips/rvv_standalone_test.c @@ -0,0 +1,191 @@ +// Standalone QEMU smoke-test for my_matmul_kernel. +// +// Build for QEMU (no libc): +// clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ +// -O2 -static -nostdlib -ffreestanding \ +// matmul_kernel.c rvv_standalone_test.c -o rvv_test +// qemu-riscv64 -cpu rv64,v=true,vlen=128,elen=64 ./rvv_test +// +// Build for host validation (x86, scalar fallback): +// clang matmul_kernel.c rvv_standalone_test.c -O2 -o rvv_test_host +// ./rvv_test_host + +#include "matmul_kernel.h" + +#include +#include + +// ── I/O and exit ────────────────────────────────────────────────────────────── +// RISC-V target: raw ecall (no libc dependency for -nostdlib build). +// Host (x86) target: libc stdio. 
+
+#ifdef __riscv
+
+// One-argument Linux syscall via ecall (a7 = nr, a0 = arg/result).
+static long _rv_syscall1(long nr, long a0) {
+  register long _a7 __asm__("a7") = nr;
+  register long _a0 __asm__("a0") = a0;
+  __asm__ volatile("ecall" : "+r"(_a0) : "r"(_a7) : "memory");
+  return _a0;
+}
+
+// Three-argument Linux syscall via ecall.
+static long _rv_syscall3(long nr, long a0, long a1, long a2) {
+  register long _a7 __asm__("a7") = nr;
+  register long _a0 __asm__("a0") = a0;
+  register long _a1 __asm__("a1") = a1;
+  register long _a2 __asm__("a2") = a2;
+  __asm__ volatile("ecall" : "+r"(_a0) : "r"(_a7), "r"(_a1), "r"(_a2)
+                   : "memory");
+  return _a0;
+}
+
+static void sys_write(const char *buf, size_t len) {
+  _rv_syscall3(64 /*SYS_write*/, 1 /*stdout*/, (long)buf, (long)len);
+}
+
+__attribute__((noreturn)) static void sys_exit(int code) {
+  _rv_syscall1(94 /*SYS_exit_group*/, (long)code);
+  __builtin_unreachable();
+}
+
+void _start(void);  // forward-declare; entry point at bottom
+
+#else  // host x86
+
+#include <stdio.h>
+#include <stdlib.h>
+
+static void sys_write(const char *buf, size_t len) {
+  fwrite(buf, 1, len, stdout);
+}
+
+__attribute__((noreturn)) static void sys_exit(int code) { exit(code); }
+
+#endif  // __riscv
+
+// ── Minimal print helpers ───────────────────────────────────────────────────
+
+static void print(const char *s) {
+  size_t n = 0;
+  while (s[n]) ++n;
+  sys_write(s, n);
+}
+
+// Print a non-negative integer in decimal. Needed because test counts exceed
+// one digit (28 checks total), so a single '0'+count char would be wrong.
+static void print_uint(int v) {
+  char buf[12];
+  int i = 11;
+  buf[i--] = '\0';
+  if (v == 0) buf[i--] = '0';
+  while (v > 0) { buf[i--] = (char)('0' + v % 10); v /= 10; }
+  print(&buf[i + 1]);
+}
+
+// Print a float as [-]W.FFFF (4 fixed decimal places, round-half-up).
+static void print_float(float v) {
+  if (v < 0.0f) { print("-"); v = -v; }
+  int whole = (int)v;
+  int frac = (int)((v - (float)whole) * 10000.0f + 0.5f);
+  // Rounding can carry into the integer part (e.g. 1.99999 → frac == 10000);
+  // without this, such values would print as "1.0000".
+  if (frac >= 10000) { frac -= 10000; ++whole; }
+  char buf[20];
+  int i = 19;
+  buf[i--] = '\0';
+  for (int j = 0; j < 4; ++j) { buf[i--] = (char)('0' + frac % 10); frac /= 10; }
+  buf[i--] = '.';
+  if (whole == 0) { buf[i--] = '0'; }
+  else { while (whole > 0) { buf[i--] = (char)('0' + whole % 10); whole /= 10; } }
+  print(&buf[i + 1]);
+}
+
+// ── Test harness ────────────────────────────────────────────────────────────
+
+static int tests_passed = 0;
+static int tests_failed = 0;
+
+static float _fabsf(float x) { return x < 0.0f ? -x : x; }
+
+// Record a pass/fail; print details only on failure to keep output small.
+static void check(const char *name, float got, float expected) {
+  if (_fabsf(got - expected) < 1e-4f) {
+    ++tests_passed;
+  } else {
+    ++tests_failed;
+    print("  FAIL ");  print(name);
+    print(" got=");    print_float(got);
+    print(" expected="); print_float(expected); print("\n");
+  }
+}
+
+// ── Tests ───────────────────────────────────────────────────────────────────
+
+// Test 1: A * I = A (4×4, row-major)
+static void test_identity(void) {
+  print("[1] A * I = A (4x4 row-major)\n");
+  float A[16] = { 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16 };
+  float I[16] = { 1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1 };
+  float C[16] = {0};
+  my_matmul_kernel(A, 0,4,1, I, 0,4,1, C, 0,4,1, 4,4,4);
+  for (int i = 0; i < 16; ++i) check("A*I", C[i], A[i]);
+}
+
+// Test 2: 2x3 * 3x2 = 2x2 → [[58,64],[139,154]]
+static void test_2x3x2(void) {
+  print("[2] 2x3 * 3x2 = 2x2\n");
+  float A[6] = {1,2,3, 4,5,6};
+  float B[6] = {7,8, 9,10, 11,12};
+  float C[4] = {0};
+  my_matmul_kernel(A, 0,3,1, B, 0,2,1, C, 0,2,1, 2,2,3);
+  check("C[0,0]", C[0], 58.0f);
+  check("C[0,1]", C[1], 64.0f);
+  check("C[1,0]", C[2], 139.0f);
+  check("C[1,1]", C[3], 154.0f);
+}
+
+// Test 3: column-major strides
+static void test_col_major(void) {
+  print("[3] col-major strides\n");
+  // A[2×3] col-major: stored [1,4, 2,5, 3,6], s0=1, s1=2
+  float A[6] = {1,4, 2,5, 3,6};
+  // B[3×2] col-major: stored [7,9,11, 8,10,12], s0=1, s1=3
+  float B[6] = {7,9,11, 8,10,12};
+  float C[4] = {0};
+  my_matmul_kernel(A, 0,1,2, B, 0,1,3, C, 0,1,2, 2,2,3);
+  // C[m,n] stored at C[m + n*2]
+  check("C[0,0]", C[0], 58.0f);
+  check("C[1,0]", C[1], 139.0f);
+  check("C[0,1]", C[2], 64.0f);
+  check("C[1,1]", C[3], 154.0f);
+}
+
+// Test 4: non-zero base offset (first 4 elements are poison sentinels)
+static void test_offset(void) {
+  print("[4] non-zero base offset\n");
+  float A[8] = {99,99,99,99, 1,0, 0,1};
+  float B[8] = {99,99,99,99, 3,0, 0,5};
+  float C[8] = {0};
+  my_matmul_kernel(A, 4,2,1, B, 4,2,1, C, 4,2,1, 2,2,2);
+  check("C[0,0]", C[4], 3.0f);
+  check("C[0,1]", C[5], 0.0f);
+  check("C[1,0]", C[6], 0.0f);
+  check("C[1,1]", C[7], 5.0f);
+}
+
+// ── Entry point ─────────────────────────────────────────────────────────────
+
+#ifdef __riscv
+void _start(void) {
+#else
+int main(void) {
+#endif
+  print("=== rvv_matmul standalone test");
+#ifdef __riscv_vector
+  print(" [RVV]\n");
+#else
+  print(" [scalar]\n");
+#endif
+
+  test_identity();
+  test_2x3x2();
+  test_col_major();
+  test_offset();
+
+  print("\n");
+  print(tests_failed == 0 ? "PASSED" : "FAILED");
+  print(" (");
+  print_uint(tests_passed);
+  print(" passed, ");
+  print_uint(tests_failed);
+  print(" failed)\n");
+
+  sys_exit(tests_failed == 0 ? 0 : 1);
+#ifndef __riscv
+  return 0;
+#endif
+}