From 3c8b8ee011870984f99d884cdd4915e3fd93b65e Mon Sep 17 00:00:00 2001 From: "yan.wang" Date: Wed, 3 Jun 2026 14:50:48 +0800 Subject: [PATCH] [BACKEND][EVAS] Integrete Evas Backend --- CMakeLists.txt | 6 +- third_party/evas/.gitignore | 15 + third_party/evas/CMakeLists.txt | 59 + third_party/evas/backend/__init__.py | 1 + third_party/evas/backend/compiler.py | 1 + third_party/evas/backend/driver.py | 2 + third_party/evas/backend/evas.py | 1113 +++++++++++++++++ third_party/evas/backend/include/epu/memory.h | 296 +++++ .../evas/backend/include/evas/helper_hpe.h | 269 ++++ third_party/evas/backend/name.conf | 1 + third_party/evas/backend/triton_evas.cc | 10 + third_party/evas/backend/utils.py | 67 + .../evas/bin/evas-triton-opt/CMakeLists.txt | 30 + .../bin/evas-triton-opt/evas-triton-opt.cpp | 43 + .../TritonToEvas/TritonToEvasPipeline.h | 23 + .../evas/Dialect/Linalg/IR/CMakeLists.txt | 8 + .../evas/Dialect/Linalg/IR/LinalgEnumsExt.td | 21 + .../evas/Dialect/Linalg/IR/LinalgOpsExt.h | 19 + .../evas/Dialect/Linalg/IR/LinalgOpsExt.td | 95 ++ .../evas/Transform/Linalg/CMakeLists.txt | 3 + .../evas/Transform/Linalg/MemoryAlloc.h | 111 ++ .../include/evas/Transform/Linalg/Passes.h | 33 + .../include/evas/Transform/Linalg/Passes.td | 171 +++ .../Transform/Linalg/RegionMemAllocator.h | 268 ++++ .../Conversion/TritonToEvas/CMakeLists.txt | 20 + .../TritonToEvas/TritonArithToLinalgNamed.cpp | 277 ++++ .../TritonToEvas/TritonToEvasPipeline.cpp | 66 + .../evas/lib/Dialect/Linalg/IR/CMakeLists.txt | 13 + .../lib/Dialect/Linalg/IR/LinalgOpsExt.cpp | 271 ++++ .../evas/lib/Transform/Linalg/Bufferize.cpp | 60 + .../evas/lib/Transform/Linalg/CMakeLists.txt | 21 + .../lib/Transform/Linalg/DoubleBuffer.cpp | 233 ++++ .../Transform/Linalg/EncapsulateLinalgOp.cpp | 583 +++++++++ .../lib/Transform/Linalg/InsertDeallocOp.cpp | 93 ++ .../Linalg/MaterializeAnnotation.cpp | 119 ++ .../evas/lib/Transform/Linalg/MemoryAlloc.cpp | 139 ++ .../lib/Transform/Linalg/MemoryAllocPass.cpp | 141 +++ .../Transform/Linalg/MemoryPromotionPass.cpp | 161 +++ .../Transform/Linalg/RegionMemAllocator.cpp | 401 ++++++ .../RemoveLoopIterArgsWithMemrefType.cpp | 247 ++++ .../Linalg/RemoveRedundencyCopyPass.cpp | 215 ++++ .../lib/Transform/Linalg/RemoveScalar.cpp | 154 +++ .../lib/Transform/Linalg/RewriteDataType.cpp | 329 +++++ .../Linalg/RewriteFuncOpArgsType.cpp | 103 ++ .../lib/Transform/Linalg/SetDeviceInfo.cpp | 302 +++++ .../Transform/Linalg/SetMemRefScopePass.cpp | 294 +++++ .../Transform/Linalg/SplitComputationalOp.cpp | 664 ++++++++++ .../patches/triton-shared-llvm22-compat.patch | 75 ++ 48 files changed, 7643 insertions(+), 3 deletions(-) create mode 100644 third_party/evas/.gitignore create mode 100644 third_party/evas/CMakeLists.txt create mode 100644 third_party/evas/backend/__init__.py create mode 100644 third_party/evas/backend/compiler.py create mode 100644 third_party/evas/backend/driver.py create mode 100644 third_party/evas/backend/evas.py create mode 100644 third_party/evas/backend/include/epu/memory.h create mode 100644 third_party/evas/backend/include/evas/helper_hpe.h create mode 100644 third_party/evas/backend/name.conf create mode 100644 third_party/evas/backend/triton_evas.cc create mode 100644 third_party/evas/backend/utils.py create mode 100644 third_party/evas/bin/evas-triton-opt/CMakeLists.txt create mode 100644 third_party/evas/bin/evas-triton-opt/evas-triton-opt.cpp create mode 100644 third_party/evas/include/evas/Conversion/TritonToEvas/TritonToEvasPipeline.h create mode 100644 third_party/evas/include/evas/Dialect/Linalg/IR/CMakeLists.txt create mode 100644 third_party/evas/include/evas/Dialect/Linalg/IR/LinalgEnumsExt.td create mode 100644 third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.h create mode 100644 third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.td create mode 100644 third_party/evas/include/evas/Transform/Linalg/CMakeLists.txt create mode 100644 third_party/evas/include/evas/Transform/Linalg/MemoryAlloc.h create mode 100644 third_party/evas/include/evas/Transform/Linalg/Passes.h create mode 100644 third_party/evas/include/evas/Transform/Linalg/Passes.td create mode 100644 third_party/evas/include/evas/Transform/Linalg/RegionMemAllocator.h create mode 100644 third_party/evas/lib/Conversion/TritonToEvas/CMakeLists.txt create mode 100644 third_party/evas/lib/Conversion/TritonToEvas/TritonArithToLinalgNamed.cpp create mode 100644 third_party/evas/lib/Conversion/TritonToEvas/TritonToEvasPipeline.cpp create mode 100644 third_party/evas/lib/Dialect/Linalg/IR/CMakeLists.txt create mode 100644 third_party/evas/lib/Dialect/Linalg/IR/LinalgOpsExt.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/Bufferize.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/CMakeLists.txt create mode 100644 third_party/evas/lib/Transform/Linalg/DoubleBuffer.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/EncapsulateLinalgOp.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/InsertDeallocOp.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/MaterializeAnnotation.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/MemoryAlloc.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/MemoryAllocPass.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/MemoryPromotionPass.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RegionMemAllocator.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RemoveLoopIterArgsWithMemrefType.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RemoveRedundencyCopyPass.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RemoveScalar.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RewriteDataType.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/RewriteFuncOpArgsType.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/SetDeviceInfo.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/SetMemRefScopePass.cpp create mode 100644 third_party/evas/lib/Transform/Linalg/SplitComputationalOp.cpp create mode 100644 third_party/evas/patches/triton-shared-llvm22-compat.patch diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f7f344e12..51604f0f42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,7 +231,7 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) if(NOT MSVC) if(FLAGTREE_BACKEND) - if(FLAGTREE_BACKEND MATCHES "^(enflame|hcu|thrive)$") + if(FLAGTREE_BACKEND MATCHES "^(enflame|hcu|thrive|evas)$") # Suppress visibility warnings in gluon_ir.cc (GCC 13+ -Wattributes on pybind11 hidden types) # Also suppress -Wcomment for generated TritonGPUAttrDefs.h.inc (ASCII diagrams in TableGen output) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-attributes -Wno-comment") @@ -259,7 +259,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files # link_directories(${LLVM_LIBRARY_DIR}) -if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro|enflame|thrive)$") +if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro|enflame|thrive|evas)$") include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files add_subdirectory(include) @@ -554,7 +554,7 @@ find_package(Threads REQUIRED) add_subdirectory(third_party/f2reduce) -if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro|enflame|thrive)$") +if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|tsingmicro|enflame|thrive|evas)$") add_subdirectory(bin) if(FLAGTREE_TLE) flagtree_add_tle_generated_header_dependencies() diff --git a/third_party/evas/.gitignore b/third_party/evas/.gitignore new file mode 100644 index 0000000000..64ea787472 --- /dev/null +++ b/third_party/evas/.gitignore @@ -0,0 +1,15 @@ +__pycache__/ +.cache/ +*.py[cod] +*.egg-info/ +.pytest_cache/ +.mypy_cache/ +build/ +dist/ +*.so +*.elf +*.mlir +*.log +accuracy_result.json +runtime/ev_torch/build/ +runtime/ev_torch/ev_torch/lib/ diff --git a/third_party/evas/CMakeLists.txt b/third_party/evas/CMakeLists.txt new file mode 100644 index 0000000000..4dce593ebb --- /dev/null +++ b/third_party/evas/CMakeLists.txt @@ -0,0 +1,59 @@ +cmake_minimum_required(VERSION 3.20) + +project(flagtree_evas LANGUAGES C CXX) + +include(FetchContent) + +set(EVAS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +set(EVAS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") + +set(TRITON_SHARED_GIT_REPOSITORY + "https://github.com/facebookincubator/triton-shared.git") +set(TRITON_SHARED_GIT_TAG "e0c513317d9e7838b00730fd494e3372400e93db") +set(TRITON_SHARED_SOURCE_DIR "${EVAS_BINARY_DIR}/_deps/evas_triton_shared-src") +set(TRITON_SHARED_BINARY_DIR "${EVAS_BINARY_DIR}/third_party/triton_shared") +set(TRITON_SHARED_LLVM_COMPAT_PATCH + "${EVAS_SOURCE_DIR}/patches/triton-shared-llvm22-compat.patch") + +if(NOT EXISTS "${TRITON_SHARED_LLVM_COMPAT_PATCH}") + message(FATAL_ERROR "missing triton_shared compatibility patch: ${TRITON_SHARED_LLVM_COMPAT_PATCH}") +endif() + +if(NOT TARGET TritonSharedAnalysis) + find_package(Git REQUIRED) + FetchContent_Declare( + evas_triton_shared + GIT_REPOSITORY "${TRITON_SHARED_GIT_REPOSITORY}" + GIT_TAG "${TRITON_SHARED_GIT_TAG}" + PATCH_COMMAND + "${CMAKE_COMMAND}" -E chdir /bin/bash -c + "\"${GIT_EXECUTABLE}\" apply --check \"${TRITON_SHARED_LLVM_COMPAT_PATCH}\" && \"${GIT_EXECUTABLE}\" apply \"${TRITON_SHARED_LLVM_COMPAT_PATCH}\" || \"${GIT_EXECUTABLE}\" apply --reverse --check \"${TRITON_SHARED_LLVM_COMPAT_PATCH}\"") + FetchContent_GetProperties(evas_triton_shared) + if(NOT evas_triton_shared_POPULATED) + FetchContent_Populate(evas_triton_shared) + endif() + + set(TRITON_SHARED_SOURCE_DIR "${evas_triton_shared_SOURCE_DIR}") + add_subdirectory("${TRITON_SHARED_SOURCE_DIR}" "${TRITON_SHARED_BINARY_DIR}") +endif() + +set(TRITON_SHARED_TOOLS_DIR "${TRITON_SHARED_SOURCE_DIR}/tools/triton-shared-opt") + +include_directories(${EVAS_SOURCE_DIR}/include) +include_directories(${EVAS_BINARY_DIR}/include) +include_directories(${EVAS_SOURCE_DIR}/backend/include) +include_directories(${TRITON_SHARED_SOURCE_DIR}/include) +include_directories(${TRITON_SHARED_BINARY_DIR}/include) + +add_subdirectory(include/evas/Dialect/Linalg/IR) +add_subdirectory(include/evas/Transform/Linalg) + +add_subdirectory(lib/Dialect/Linalg/IR) +add_subdirectory(lib/Transform/Linalg) +add_subdirectory(lib/Conversion/TritonToEvas) +add_subdirectory(bin/evas-triton-opt) + +if(TRITON_BUILD_PYTHON_MODULE AND COMMAND add_triton_plugin) + add_triton_plugin(TritonEVAS ${EVAS_SOURCE_DIR}/backend/triton_evas.cc) + target_link_libraries(TritonEVAS PRIVATE Python3::Module pybind11::headers) +endif() diff --git a/third_party/evas/backend/__init__.py b/third_party/evas/backend/__init__.py new file mode 100644 index 0000000000..45e4714937 --- /dev/null +++ b/third_party/evas/backend/__init__.py @@ -0,0 +1 @@ +# EVAS Triton backend package. diff --git a/third_party/evas/backend/compiler.py b/third_party/evas/backend/compiler.py new file mode 100644 index 0000000000..f3b0e0c114 --- /dev/null +++ b/third_party/evas/backend/compiler.py @@ -0,0 +1 @@ +from triton.backends.evas.evas import EvasBackend diff --git a/third_party/evas/backend/driver.py b/third_party/evas/backend/driver.py new file mode 100644 index 0000000000..93e841d16e --- /dev/null +++ b/third_party/evas/backend/driver.py @@ -0,0 +1,2 @@ +from triton.backends.evas.evas import EvasDriver as ActiveDriver + diff --git a/third_party/evas/backend/evas.py b/third_party/evas/backend/evas.py new file mode 100644 index 0000000000..6dbde2a2f2 --- /dev/null +++ b/third_party/evas/backend/evas.py @@ -0,0 +1,1113 @@ +##### Add EVBackend here +import hashlib +from tabnanny import check +import os, tempfile +import sysconfig +import numpy as np +import importlib.util +from pathlib import Path +import site +import sys +import logging +import re +import subprocess +import functools + +from triton.runtime.cache import get_dump_manager, get_cache_manager +from triton.backends.driver import DriverBase +from triton.backends.compiler import BaseBackend, GPUTarget + + +from .utils import run_command, _cache_key + +from triton._C import libtriton +from triton._C.libtriton import ir, passes +try: + from triton._C.libtriton import triton_shared +except ImportError: + triton_shared = None +from dataclasses import dataclass +from typing import Any, Tuple, Dict +from types import ModuleType + +CLUSTER_NUM = 4 +CORE_NUM = 4 +# EVAS Driver +# -------------------- Launcher ---------------------------- +def _ty_to_cpp(ty): + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + "*fp32": "float*", + "*fp16": "_Float16*", + "*i1": "int8_t*", + "*i8": "int8_t*", + "*i16": "int16_t*", + "*i32": "int32_t*", + "*i64": "int64_t*", + "*u1" : "int8_t*", + "*u8" : "int8_t*", + "*bf16": "bfloat16_t*" + }[ty] + + +def compile_module(signature, constants, launcher_src, arg_names, launch_call_placeholder): + params = list(signature.items()) + + def gen_args(signature, args): + # kernel_arg_decls = [_ty_to_cpp(ty) for i, ty in signature.items() if i not in constants and ty[0] == "*"] + args_desc = [] + for index, (arg_idx, ty) in enumerate(params): + arg = args[index] + is_output = arg_names[arg_idx].startswith("out") + is_pointer = ty[0] == "*" + is_tensor = hasattr(arg, 'numel') and hasattr(arg, 'element_size') + args_desc.append( + { + "tensor": arg, + "size": arg.numel() * arg.element_size() if is_pointer and is_tensor else None, + "arg_idx": arg_idx, + "dtype": arg.dtype if is_pointer and is_tensor else None, + "arg_type": _ty_to_cpp(ty), + "is_pointer": is_pointer, + "addr": arg.data_ptr() if is_pointer and is_tensor else None, + "is_output": is_output, + "is_cpu": arg.device.type == "cpu" if is_pointer and is_tensor else None + } + ) + return args_desc + + def normalize_args(args): + if len(args) == len(params): + return list(args) + return [arg for i, arg in enumerate(args) if i in signature.keys()] + + def pack_launcher(args_desc, launcher_src, kernel_name): + packed = launcher_src + host_to_device_code = ' '.join(f'void* host_ptr{i} = ptr_info{i}.dev_ptr; {desc["arg_type"]} device_ptr{i}; evMalloc((void **)&device_ptr{i}, {desc["size"]}, mem_affinitymap); evMemcpy(device_ptr{i}, host_ptr{i}, {desc["size"]}, evMemcpyHostToDevice, mem_affinitymap); printf("[LOG] Copy from host %p to device %p\\n", host_ptr{i}, device_ptr{i}); ptr_info{i}.dev_ptr = (void*)device_ptr{i};' for i, desc in enumerate(args_desc) if desc['is_pointer'] and desc['is_cpu']) + device_to_host_code = ' '.join(f'evMemcpy(host_ptr{i}, device_ptr{i}, {desc["size"]}, evMemcpyDeviceToHost, mem_affinitymap); printf("[LOG] Copy from device %p to host %p\\n", device_ptr{i}, host_ptr{i});' for i, desc in enumerate(args_desc) if desc['is_pointer'] and desc['is_output'] and desc['is_cpu']) + return packed.replace("HOST_TO_DEVICE_PLACE_HOLDER", host_to_device_code).replace("DEVICE_TO_HOST_PLACE_HOLDER", device_to_host_code).replace(launch_call_placeholder, kernel_name) + + + def launch( + gridX, + gridY, + gridZ, + stream, + cu_function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + *args, + ): + asm_src = cu_function + kernel_name = kernel_metadata[6] + args = normalize_args(args) + args_desc = gen_args(signature, args) + src = pack_launcher(args_desc, launcher_src, kernel_name) + module_key = hashlib.sha256(src.encode("utf-8") + asm_src).hexdigest() + name = f"__triton_ref_epu_kernel_launcher_{module_key[:16]}" + src = src.replace("MODULE_NAME_PLACEHOLDER", name) + key = hashlib.sha256(src.encode("utf-8") + asm_src).hexdigest() + cache = get_cache_manager(key) + filename = f"{name}.so" + cache_path = cache.get_file(filename) + + if cache_path is None: + module = compile_module_from_src(src.encode("utf-8")) + cache.put(src, f"{name}.cc") + cache_path = cache.put(module, filename, binary=True) + + # Load and launch the compiled kernel. + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + with tempfile.TemporaryDirectory() as tmpdir: + asm_src_path = os.path.join(tmpdir, "kernel.elf") + Path(asm_src_path).write_bytes(asm_src) + return mod.launch(gridX, gridY, gridZ, asm_src_path, + kernel_metadata, launch_metadata, + launch_enter_hook, launch_exit_hook, + *args) + + return launch + +def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return _ty_to_cpp(ty) + +def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + +def _generate_launcher(constants, signature, kernel_name): + args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiisOOOO" + args_format + + args_list = ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + + params = [(i, ty) for i, ty in signature.items() if i not in constants] + # kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"void*" for i, ty in params) + # kernel_arg_decls += ', ' if kernel_arg_decls else '' + + # kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"arg{i}" for i, ty in params) + # kernel_parameters += ', ' if kernel_parameters else '' + + return f""" +#include +#include +#include +#include +#include +#include +#include +#include + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + const char *kernel_elf_path = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &kernel_elf_path, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook, {args_list})) {{ + return NULL; + }} + + // [CPULauncher-specific]: We don't need the metadata below but just put them + // here anyway to be consistent with others. + // This will make updating the driver easier in the future. + + // int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + // if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + // PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + // return NULL; + // }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + void *input_buf; + unsigned int file_size; + evModule_t module; + const char *kernel_name = "{kernel_name}"; + const uint64_t mem_affinitymap = evAffinityMapDefault; + const uint64_t kernel_affinitymap = evAffinityMapDefault; + checkErrors(evSetDevice(0)); + ReadBinFile(kernel_elf_path, &input_buf, &file_size); + // Logging before evModuleLoadData + printf("[LOG] Calling evModuleLoadData\\n"); + checkErrors(evModuleLoadData(&module, input_buf, file_size)); + printf("[LOG] Finished evModuleLoadData\\n"); + + evFunc hfunc; + + // Logging before evModuleGetFunction + printf("[LOG] Calling evModuleGetFunction\\n"); + checkErrors(evModuleGetFunction(&hfunc, module, kernel_name)); + printf("[LOG] Finished evModuleGetFunction\\n"); + + // Logging before checkNotEqual + printf("[LOG] Checking if hfunc is not NULL\\n"); + checkNotEqual(hfunc, (evFunc)NULL); + printf("[LOG] hfunc is valid\\n"); + + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in params])}; + + HOST_TO_DEVICE_PLACE_HOLDER + int grid_args[3] = {{gridX, gridY, gridZ}}; + printf("[LOG] grid_args: %d, %d, %d\\n", grid_args[0], grid_args[1], grid_args[2]); + void *kernel_args[{(len(params) + 3) * 2 + 1}] = {{ {", ".join([f"(void *)&(ptr_info{i}.dev_ptr), (void*)8" if ty[0]=="*" else f"(void*)(&_arg{i}), (void*)8" for i, ty in params])}, {", ".join([f"(void*)(&(grid_args[{i}])), (void*)8" for i in range(3)])}}}; + kernel_args[{(len(params) + 3) * 2}] = NULL; + // Logging before evConfigureCall + printf("[LOG] Calling evConfigureCall\\n"); + checkErrors(evConfigureCall({CLUSTER_NUM}, {CORE_NUM}, 0, kernel_affinitymap, evKernelFlushCache)); + printf("[LOG] Finished evConfigureCall with ClusterNum: %d, CoreNum: %d\\n", {CLUSTER_NUM}, {CORE_NUM}); + + // Logging before evLaunchKernel + printf("[LOG] Calling evLaunchKernel\\n"); + checkErrors(evLaunchKernel(&module, kernel_name, kernel_args)); + printf("[LOG] Finished evLaunchKernel\\n"); + + DEVICE_TO_HOST_PLACE_HOLDER + + // Logging before evModuleUnload + printf("[LOG] Calling evModuleUnload\\n"); + checkErrors(evModuleUnload(module)); + printf("[LOG] Finished evModuleUnload\\n"); + // Logging before evDeviceSynchronize + printf("[LOG] Calling evDeviceSynchronize\\n"); + evDeviceSynchronize(); + printf("[LOG] Finished evDeviceSynchronize\\n"); + if (PyErr_Occurred()) {{ + return NULL; + }} + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"MODULE_NAME_PLACEHOLDER\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit_MODULE_NAME_PLACEHOLDER(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + +def get_sys_path(bin_name: str): + # Get the path to the system clang++ + path = shutil.which(bin_name) + if path is None: + raise RuntimeError(bin_name + " not found") + return path + +class EvasLauncher(object): + + def __init__(self, src, metadata): + kernel_name_placeholder = "KERNEL_NAME_PLACEHOLDER" + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items() if value != 'constexpr'} + # constants has no use here and can be removed + launcher_src = _generate_launcher(constants, signature, kernel_name_placeholder) + # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name + # in the following launch function. + self.launch = compile_module(signature, constants, launcher_src, src.fn.arg_names, kernel_name_placeholder) + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class EvasUtils(object): + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(EvasUtils, cls).__new__(cls) + return cls.instance + + # Note: + # nvidia and amd backends have their corresponding driver.c file that exposes + # get_device_properties and load_binary using python bindings. + # (see third_party/nvidia/backend/driver.c) + # These methods are then used in compiler.py to initialize handles before running + # the triton kernels. + # Since we recompile the kernel every time (see compile_module above), + # and the metadata generated by these functions aren't applicable to the cpu + # backend, just define the same functions with dummy implementation. + @staticmethod + def get_device_properties(device): + return { + "max_shared_mem": 2**20, + "multiprocessor_count": None, + "sm_clock_rate": None, + "mem_clock_rate": None, + "mem_bus_width": None, + } + + # Important note: + # Since we cannot easy pass function pointers around, we pass along the + # assembly source code so that compile_module above can recompile the + # module every time. + @staticmethod + def load_binary(name, kernel_obj, shared, device): + return ( + None, # module + kernel_obj, # function + None, # n_regs + 4, # n_spills + sys.maxsize, # n_max_threads + ) + + +class EvasDriver(DriverBase): + + def __init__(self): + super().__init__() + self.utils = EvasUtils() + self.launcher_cls = EvasLauncher + self.binary_ext = "elf" ## cc or elf?? + + # CPU driver won't be automatically chosen unless explicitly set through + # triton.runtime.driver.set_active(CPUDriver()) + @staticmethod + def is_active(): + return True + + def map_python_to_cpp_type(self, ty: str) -> str: + """ + Converts a Triton type string to its corresponding C++ type string for EV backend. + """ + type_mapping = { + "i1": "int8_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint8_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + "*fp32": "float*", + "*fp16": "float*", + "*i1": "int8_t*", + "*i8": "int8_t*", + "*i16": "int16_t*", + "*i32": "int32_t*", + "*i64": "int64_t*", + "*u8": "uint8_t*" + } + return type_mapping.get(ty, "void*") + + def get_device_capability(self): + return ("epu", 0) + + def get_active_torch_device(self): + """ + Return the active torch device for this backend. + For EV backend, we don't use torch, so return None. + """ + return None + + def get_benchmarker(self): + """ + Return the benchmarking function that this backend should use by default. + For EV backend, we provide a simple benchmarker. + """ + def simple_benchmarker(kernel_call, *, quantiles=None, **kwargs): + # Simple benchmarker that just calls the kernel once + # In a real implementation, you might want to call it multiple times + # and measure performance + result = kernel_call(**kwargs) + if quantiles is None: + quantiles = [0.5] # Default to median + return [0.0] * len(quantiles) # Placeholder timing + return simple_benchmarker + + def get_current_stream(self, device): + return None + + def get_current_device(self): + # CPU doesn't have a device to return. Return something. + return "epu" + + def set_current_device(self, device): + # CPU doesn't have a device to set + assert device == "epu" + return + + def get_current_target(self): + return GPUTarget("epu", 0, 0) + + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args + + +# EVAS Compiler + + +def _make_ttir(mod, metadata, opt): + ttir_code = str(mod) + pattern = r"tt\.func public @(.*?)(?=\()" + matches = re.findall(pattern, ttir_code) + metadata["name"] = matches[0] + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) ## triton 3.7 deleted + passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + pm.run(mod, "make_ttir") + return mod.str() + + +def _get_evas_triton_opt_path() -> str: + path = os.getenv("EVAS_TRITON_OPT_PATH", None) + if path == None: + raise Exception( + "EVAS_TRITON_OPT_PATH is not set. Build FlagTree with the EVAS backend or set EVAS_TRITON_OPT_PATH." + ) + return path + + + +def ttshared_opt_command(src, dst, metadata): + evas_triton_opt_path = _get_evas_triton_opt_path() + ret = [ + evas_triton_opt_path, + src, + "--pass-pipeline=builtin.module(triton-to-evas)", + "-o", + dst, + ] + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + if enable_ir_dump: + fn_dump_manager = get_dump_manager(metadata["hash"]) + ret.extend( + [ + "--mlir-print-ir-after-all", + "--mlir-print-ir-tree-dir=" + fn_dump_manager.cache_dir + "/ttshared_ir", + ] + ) + return ret + + +def _ttir_to_ttsharedir(ttir_code, metadata): + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tt.mlir") + dst_path = os.path.join(tmpdir, "ttshared.mlir") + Path(src_path).write_text(ttir_code) + cmd = ttshared_opt_command(src_path, dst_path, metadata) + run_command(cmd) + return Path(dst_path).read_text() + + +def _optimize_ttsharedir(ttsharedir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return ttsharedir + +def _rename_entry_func_to_kernel(ttsharedir: str, metadata) -> str: + original_name = metadata.get("name") + if original_name == "kernel": + return ttsharedir + if original_name: + ttsharedir = re.sub( + rf'(\bsym_name\s*=\s*)"{re.escape(original_name)}"', + r'\1"kernel"', + ttsharedir, + count=1, + ) + ttsharedir = re.sub( + rf'(\bfunc\.func\s+(?:public\s+|private\s+)?@){re.escape(original_name)}\b', + r'\1kernel', + ttsharedir, + count=1, + ) + if 'sym_name = "kernel"' not in ttsharedir and "func.func @kernel" not in ttsharedir: + ttsharedir = re.sub(r'(\bsym_name\s*=\s*)"[^"]+"', r'\1"kernel"', ttsharedir, count=1) + return ttsharedir + + +def pack_header(ccsrc: str, metadata): + def _add_extern_device_to_kernel(ccsrc: str) -> str: + # Find the line containing "MCU void kernel(" and add extern "C" __global__ prefix + # Replace "MCU void kernel(" with "extern \"C\" __global__ MCU void kernel(" in the entire ccsrc + ccsrc = re.sub(r'MCU\s+void\s+kernel\(', '__device__ MCU void kernel(', ccsrc) + return ccsrc + + def _add_device_to_func_def(ccsrc: str) -> str: + # Find all function definitions that are not "kernel" and add __device__ prefix + # This regex matches function definitions like "void func_name(" or "int func_name(" etc. + # but excludes the kernel function + # Add __device__ only to function definitions that start with "MCU ALWAYS_INLINE" and are not kernel + pattern = r'(?!.*__device__)(\bMCU.+\([^;{]*\)\s*{)' + # matches = list(re.finditer(pattern, ccsrc, flags=re.MULTILINE)) + # for match in matches: + # print(match.group(0)) + ccsrc = re.sub( + pattern, # Match MCU ALWAYS_INLINE function def but not kernel + r'__device__ \1', + ccsrc, + flags=re.MULTILINE + ) + return ccsrc + + def _extract_kernel_function_args(ccsrc: str): + """ + Extract the kernel function arguments (names and types) from C/C++ code. + + Args: + cc_code (str): The C/C++ code string + + Returns: + dict: A dictionary where keys are argument names and values are their types + Returns None if no kernel function is found + """ + # Pattern to match the kernel function declaration + # Looks for "void kernel(" or "MCU void kernel(" or similar patterns + kernel_pattern = r'(?:__global__\s+)?(?:extern\s+"C"\s+)?(?:MCU\s+)?(?:__device__\s+)?void\s+kernel\s*\((.*?)\)' + + # Find the kernel function + kernel_match = re.search(kernel_pattern, ccsrc, re.DOTALL) + + if not kernel_match: + raise RuntimeError("No kernel function found in the provided C/C++ code.") + + # Extract the arguments section + args_section = kernel_match.group(1) + + # Handle empty arguments case + if args_section.strip() == "": + raise RuntimeError("Kernel function has no arguments.") + + # Split multiple arguments + args = [] + + # Handle complex arguments with nested commas (like template parameters) + depth = 0 + current_arg = "" + + for char in args_section: + if char == ',' and depth == 0: + args.append(current_arg.strip()) + current_arg = "" + else: + if char == '<': + depth += 1 + elif char == '>': + depth -= 1 + current_arg += char + + # Add the last argument + if current_arg.strip(): + args.append(current_arg.strip()) + + # Process each argument to extract type and name + arg_dict = [] + + for arg in args: + # Pattern to match type and name: everything before the last word is the type + arg = arg.strip() + + # Skip empty arguments + if not arg: + continue + + # Handle pointers and references in names + name_pattern = r'(\w+)(?:\s*\[\s*\w*\s*\])*\s*$' + name_match = re.search(name_pattern, arg) + + if name_match: + name = name_match.group(1) + # Extract the type by removing the name from the end + type_part = arg[:arg.rfind(name)].strip() + + # Clean up any trailing spaces, asterisks should stick with the type + if type_part.endswith('*'): + while type_part.endswith(' *'): + type_part = type_part[:-2] + '*' + + arg_dict.append({"name": name, "type": type_part}) + else: + # For arguments that don't match the expected pattern + raise ValueError(f"Could not parse argument: '{arg}'") + + return arg_dict + + def _delete_MCU(ccsrc: str) -> str: + # Replace all occurrences of "MCU" with "" + return re.sub(r'\bMCU\s*', '', ccsrc) + + def _replace_wt_data_setction(ccsrc: str) -> str: + # Replace all occurrences of ".wtdata" with ".data" + return ccsrc.replace('.wtdata', '.data') + + def _delete_always_inline(ccsrc: str) -> str: + # Replace all occurrences of "ALWAYS_INLINE" with "" + return re.sub(r'\bALWAYS_INLINE\s*', '', ccsrc) + + kernel_name = metadata["name"] + ccsrc = _add_device_to_func_def(ccsrc) + ccsrc = _delete_MCU(ccsrc) + #ccsrc = _delete_always_inline(ccsrc) + ccsrc = _replace_wt_data_setction(ccsrc) + args = _extract_kernel_function_args(ccsrc) + outer_args = args[:-3] + gridx, gridy, gridz = outer_args[-3:] + return f""" +#ifdef __AC_DEVICE_COMPILE__ +#include "visa_defs_v1.h" +#include "visa.h" +#include "operator/te/te.h" +#include "target/evamind/evamind-v1/visa_evamind_me.h" +#include "target/evamind/visa_evamind_te.h" +#include "target/evamind/visa_evamind_vcall.h" +#include "util/tuple.h" +#include "visa_matrix.h" + +#include +using namespace visa::te; +using namespace visa::matrix; +using namespace visa; +#endif + +#include +#include +{ccsrc} + +__device__ int get_hw_core_id(int dim_x, int dim_y, int dim_z, int pid_x, int pid_y, int pid_z){{ + int total_size = dim_x * dim_y * dim_z; + int total_core_num = {CORE_NUM} * {CLUSTER_NUM}; + // Calculate elements per core (rounded up) + int elements_per_core = (total_size + total_core_num - 1) / total_core_num; + // Calculate linear index + int linear_index = pid_x + pid_y * dim_x + pid_z * dim_x * dim_y; + // Calculate target core + int hw_id = linear_index / elements_per_core; + // Ensure we don't exceed available cores + hw_id = hw_id % total_core_num; + return hw_id; +}} + + +extern "C" __global__ void {kernel_name}({', '.join([f"{arg['type']} {arg['name']}" for arg in outer_args])}) {{ + int core_id = 0; + int cluster_id = 0; + for(int z = 0; z < {gridz['name']}; z++) {{ + for(int y = 0; y < {gridy['name']}; y++) {{ + for(int x = 0; x < {gridx['name']}; x++) {{ + int hw_core_id = get_hw_core_id({gridx['name']}, {gridy['name']}, {gridz['name']}, x, y, z); + cluster_id = hw_core_id / {CORE_NUM}; + core_id = hw_core_id % {CORE_NUM}; + if (cluster_id == ClusterID && core_id == CoreID) {{ + printf("[KERNEL LOG]: run program with core_id: %d, cluster_id: %d\\n", core_id, cluster_id); + kernel({', '.join([f"{arg['name']}" for arg in outer_args])}, x, y, z); + }} + }} + }} + }} +}} +""" + +def get_py_include_dir(): + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + return sysconfig.get_paths(scheme=scheme)["include"] + +def compile_mod_command(src_path: str, dst_path: str): + return [ + "evcc", + "-O2", + "-Wall", + "-Werror", + "-std=c++17", + "-fno-omit-frame-pointer", + "-Wno-unused-command-line-argument", + "-Wno-implicitly-unsigned-literal", + "-Wno-unused-variable", + "-ftls-model=local-exec", + "-fno-common", + "-ffast-math", + "-ffunction-sections", + "-fdata-sections", + "-Wl,--gc-sections", + "-D__riscv_v_vlen=1024", + # "--gcc-toolchain=/usr", + "-shared", "-fPIC", "-o", + dst_path, + src_path, + "-I", + get_py_include_dir(), + "-I", + os.path.join(os.path.dirname(__file__), "include/evas") + ] + + +def cc_to_kernel_elf_command(src_path: str, dst_path: str): + visa_dir = os.environ.get("VISA_PATH") + if visa_dir is None: + sdk_root = os.environ.get("EVAS_SDK_ROOT") + if sdk_root: + visa_dir = os.path.join(sdk_root, "application", "visa") + if visa_dir is None: + raise EnvironmentError("Please set VISA_PATH or EVAS_SDK_ROOT.") + hgss_root = os.environ.get("RISCV_HGSS") + return [ + "evcc", + "-O2", + "-Wall", + "-Werror", + "-std=c++17", + "-fno-omit-frame-pointer", + "-Wno-unused-command-line-argument", + "-Wno-implicitly-unsigned-literal", + "-Wno-unused-variable", + "-ftls-model=local-exec", + "-fno-common", + "-ffast-math", + "-ffunction-sections", + "-fdata-sections", + "-Wl,--gc-sections", + "-D__riscv_v_vlen=1024", + "-D__evamind_v1", + "-D__EVAMIND_DEVICE_V1", + # "--gcc-toolchain=/usr", + "--ac-device-only", + "-emit-device-elf", + "-x", + "ac", + "-o", + dst_path, + src_path, + "-I", + visa_dir + "/include", + "-I", + visa_dir + "/include/target/evamind/evamind-v1", + *( + ["-I", os.path.join(hgss_root, "misc", "include")] + if hgss_root + else [] + ), + "-I", + os.path.join(os.path.dirname(__file__), "include/evas") + ] + + +def compile_module_from_src(cc: bytes): + dst_name = "kernel" + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, dst_name + ".cc") + Path(src_path).write_bytes(cc) + dst_path = os.path.join(tmpdir, dst_name + ".so") + command = compile_mod_command(src_path, dst_path) + run_command(command) + return Path(dst_path).read_bytes() + +def compile_elf(cc: bytes, metadata): + dst_name = "kernel" + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, dst_name + ".cc") + Path(src_path).write_bytes(cc) + dst_path = os.path.join(tmpdir, dst_name + ".elf") + command = cc_to_kernel_elf_command(src_path, dst_path) + run_command(command) + return Path(dst_path).read_bytes() + +def evofc_lib_dir(): + evofc_dir = os.environ.get("EVOFC_OPT_PATH") + if evofc_dir is not None: + evofc_lib_dir = os.environ.get("EVOFC_OPT_PATH") + else: + site_packages_path = site.getsitepackages()[0] + evofc_lib_dir = os.path.join(site_packages_path, 'jax_plugins','evas_epu') + try: + if evofc_lib_dir == None or not os.path.exists(evofc_lib_dir): + raise FileNotFoundError( + f"Error: '{evofc_lib_dir}' does not exist, please set EVOFC_OPT_PATH." + ) + except FileNotFoundError as e: + print(e) + return evofc_lib_dir + + +def ev_opt_command(src: str, dst: str, metadata): + # visa-debug-print=1 visa-debug-core=0 + ev_opt = os.path.join(evofc_lib_dir(), "evofc-opt") + hash_suffix = metadata["hash"][:5] + incbin_dir = get_cache_manager(_cache_key(f"incbin/incbin_{hash_suffix}")).cache_dir + passes = [ + "--operands-memory-reuse-setter", + "--rewrite-reduce-copy", + "--preprocess-copy-elimination", + "--lower-affine", + "--canonicalize", + "--cse", + "--rewrite-memref-cast", + "--copy-elimination", + "--cse", + "--mc-op-collapse", + "--ev-canonicalize", + "--canonicalize", + "--cse", + "--rewrite-memref-copy", + "--rewrite-slice-deslice-to-copy", + "--ev-decompose-linalg-ops", + "--ev-collapse-dimensions", + "--rewrite-multi-dims-linalg", + "--canonicalize", + "--cse", + "--linalg-named-to-visa=simt-level=simt", + "--kernel-to-visa", + "--canonicalize", + "--cse", + "--visa-te-canonicalizer", + "--canonicalize", + "--cse", + "--rewrite-matmul-with-batch", + "--ev-collapse-dimensions-visa", + "--rewrite-visa-shape-and-stride", + "--rewrite-multi-dims-visa", + "--rewrite-visa-deform-conv2d", + "--rewrite-visa-conv2d-layout", + "--rewrite-visa-conv2d-to-matmul", + "--rewrite-visa-conv2d-kernel-hw", + "--canonicalize", + "--cse", + "--rewrite-strided-visa", + "--canonicalize", + "--cse", + "--inject-sync=simt-level=simt", + f"--visa-to-emitc=simt-level=simt incbin-dir={incbin_dir}", + "--inject-memref-transfer-func", + "--canonicalize", + "--cse", + "--func-to-emitc", + "--canonicalize", + "--cse", + "--ev-scf-to-emitc=simplify-line-break=true", + "--form-expressions", + ] + ev_opt_command = [ + ev_opt, + *passes, + src, + "-o", + dst, + ] + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + if enable_ir_dump: + fn_dump_manager = get_dump_manager(metadata["hash"]) + ev_opt_command.extend( + [ + "--mlir-print-ir-after-all", + "--mlir-print-ir-tree-dir=" + fn_dump_manager.cache_dir + "/evopt_ir", + ] + ) + return ev_opt_command + +def ev_trans_command(src: str, dst: str): + ev_trans = os.path.join(evofc_lib_dir(), "evofc-translate") + ev_trans_command = [ev_trans, "-mlir-to-cpp", src, "-o", dst] + return ev_trans_command + + + + +def _ttshared_to_cc(ttsharedir: str, metadata): + with tempfile.TemporaryDirectory() as tmpdir: + ### stub code for debugging + src_path = os.path.join(tmpdir, "ttshared.mlir") + Path(src_path).write_text(_rename_entry_func_to_kernel(ttsharedir, metadata)) + emitc_path = os.path.join(tmpdir, "emitc.mlir") + cc_path = os.path.join(tmpdir, "dst.cc") + + run_command(ev_opt_command(src_path, emitc_path, metadata)) + run_command(ev_trans_command(emitc_path, cc_path)) + return pack_header(Path(cc_path).read_text(), metadata) + +def _cc_to_evbin(ccsrc: str, metadata): + return compile_elf(ccsrc.encode("utf-8"), metadata) + + +#### Add EVBackend here +@dataclass(frozen=True) +class EvasOptions: + debug: bool = False + arch: str = None + num_warps: int = 0 + num_ctas: int = 0 + num_stages: int = 1 + instrumentation_mode: str = "" + one_tile_per_cta: bool = False + enable_warp_specialization: bool = False + enable_fp_fusion: bool = False + extern_libs = None + cluster_dims: tuple = (1, 1, 1) + shared: bool = False + allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee",) + # TODO: just for lowering to extern elementwise op (libdevice.py) + # some operations like libdevice.rsqrt and so on + backend_name = "cuda" + to_ttsharedir: bool = False + sanitize_overflow: bool = True + # Disable FP8 here since this is a sample CPU backend. + # Target specific backends can eanble it with supported types. + supported_fp8_dtypes: Tuple[str] = () + ir_override: str = None + def __post_init__(self): + pass + + def hash(self): + key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class EvasBackend(BaseBackend): + binary_ext = "elf" + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "epu" + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + + def parse_options(self, opts) -> Any: + args = {"arch": self.target.arch} + args.update( + {k: opts[k] for k in EvasOptions.__dataclass_fields__.keys() if k in opts} + ) + return EvasOptions(**args) + + def get_codegen_implementation(self, options): + codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)} + return codegen_fns + + def pack_metadata(self, metadata): + # Note: We actually don't need any of these except for the name which is + # used in the launch function in driver.py. Putting these in so we're + # consistent with other backends + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + metadata.name, + ) + + # Our compilation pipeline isn't in python like nvidia or amd, no need to load + # dialects. See `triton_shared.cc` + def load_dialects(self, ctx): + if triton_shared is not None and hasattr(triton_shared, "load_dialects"): + triton_shared.load_dialects(ctx) + return + libtriton.evas.load_dialects(ctx) + return + + def add_stages(self, stages, options, language=None): + stages["ttir"] = lambda src, metadata: _make_ttir(src, metadata, options) + if options.to_ttsharedir: + stages["elf"] = lambda src, metadata: _optimize_ttsharedir( + _ttir_to_ttsharedir(src, metadata) + ) + else: + stages["ttsharedir"] = lambda src, metadata: _optimize_ttsharedir( + _ttir_to_ttsharedir(src, metadata) + ) + stages["cc"] = lambda src, metadata: _ttshared_to_cc(src, metadata) + stages["elf"] = lambda src, metadata: _cc_to_evbin(src, metadata) + + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations + """ + # TODO: maybe no need to add module mapping for EV backend + return {} + + + @functools.lru_cache() + def hash(self): + return self.target diff --git a/third_party/evas/backend/include/epu/memory.h b/third_party/evas/backend/include/epu/memory.h new file mode 100644 index 0000000000..7f8fd21a49 --- /dev/null +++ b/third_party/evas/backend/include/epu/memory.h @@ -0,0 +1,296 @@ +/* Copyright 2024 The EVAS Intelligence Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef EV_SUPPORT_MEMORY_H_ +#define EV_SUPPORT_MEMORY_H_ + + +#include "mlir/Dialect/Affine/Analysis/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/StringRef.h" + +#include +#include + + +namespace mlir::ev { + +static constexpr llvm::StringRef scopeName = "scope"; +static constexpr llvm::StringRef addrName = "addr"; +static constexpr llvm::StringRef subKernelAddrName = "address"; +static constexpr llvm::StringRef phyAddrName = "phyAddr"; +static constexpr llvm::StringRef overflowName = "overflow"; +static constexpr llvm::StringRef previewName = "preview"; +static constexpr llvm::StringRef priorityName = "priority"; +static constexpr llvm::StringRef schedulePrimitive = "schedule_primitive"; +static constexpr llvm::StringRef prefetchName = "prefetch"; +static constexpr llvm::StringRef MEMSCOPE = "mem_scope"; + +enum MemScope : unsigned { + UNKNOWN = 0, + DDR = 1, + L2 = 2, + MM = 3, + PAM = 4, + FAM = 5, + MAX, +}; + +typedef enum { + PRIOR_MIN = 0, + PRIOR_LOW = 25, + PRIOR_MEDIUM = 50, + PRIOR_HIGH = 75, + BANK_ALONE_THRESHOLD = 100, + PRIOR_MAX = 200, +} BufferPrior; + +static inline llvm::StringRef getScopedPtrName(MemScope mem) { + switch (mem) { + case FAM: + case PAM: + return "am_ptr"; + case L2: + return "l2_ptr"; + case MM: + return "mm_ptr"; + case DDR: + return "ddr_ptr"; + default: + return ""; + } +} + +static inline llvm::StringRef memScopeToString(MemScope mem) { + assert(static_cast(mem) < static_cast(MemScope::MAX) && + "Unexpected memory scope."); + static llvm::StringRef names[] = {"UNKNOWN", "DDR", "L2", "MM", "PAM", "FAM"}; + return names[static_cast(mem)]; +} + +static inline llvm::StringRef LiveBufString(MemScope mem) { + assert(static_cast(mem) < static_cast(MemScope::MAX) && + "Unexpected memory scope."); + static llvm::StringRef names[] = {"liveUnknown", "liveDDR", "liveL2", + "liveMM", "livePAM", "liveFAM"}; + return names[static_cast(mem)]; +} + +static inline llvm::StringRef FixedBufString(MemScope mem) { + assert(static_cast(mem) < static_cast(MemScope::MAX) && + "Unexpected memory scope."); + static llvm::StringRef names[] = {"fixedUnknown", "fixedDDR", "fixedL2", + "fixedMM", "fixedPAM", "fixedFAM"}; + return names[static_cast(mem)]; +} + +static inline MemScope stringToMemScope(llvm::StringRef mem) { + static std::unordered_map map = { + {"DDR", MemScope::DDR}, {"L2", MemScope::L2}, + {"MM", MemScope::MM}, {"PAM", MemScope::PAM}, + {"FAM", MemScope::FAM}, {"liveDDR", MemScope::DDR}, + {"liveL2", MemScope::L2}, {"liveMM", MemScope::MM}, + {"livePAM", MemScope::PAM}, {"liveFAM", MemScope::FAM}}; + auto it = map.find(mem.str()); + return (it != map.end()) ? it->second : MemScope::UNKNOWN; +} + +static inline MemScope memLower(MemScope mem) { + assert((mem == MemScope::MM || mem == MemScope::L2) && + "Unexpected memory scope."); + return MemScope(static_cast(mem) - 1); +} + +// 0x1000000000ULL is the base offset of the DDR memory +// 0x40000000 is 1GB space for code section and stack/heap memory +static inline int64_t ddrAddr(int64_t offset) { return 0x1000000000ULL + 0x40000000 + (offset); } + +static inline int64_t memCapacity(MemScope mem) { + assert(static_cast(mem) < static_cast(MemScope::MAX) && + "Unexpected memory scope."); + static std::unordered_map map = { + {MemScope::DDR, 5 * 1024 * 1024 * 1024U}, + {MemScope::L2, 9 * 1024 * 1024U}, + {MemScope::MM, 3 * 512 * 1024U}, + {MemScope::PAM, 256 * 1024U}, + {MemScope::FAM, 256 * 1024U}}; + auto it = map.find(mem); + return (it != map.end()) ? it->second : 0U; +} + +static inline int64_t getPreferedAlignBytes() { return 128; } + +static inline bool memHasBank(MemScope mem) { + if (mem == MemScope::L2 || mem == MemScope::MM || mem == MemScope::PAM || + mem == MemScope::FAM) + return true; + return false; +} + +static inline int64_t memBankAlignment(MemScope mem) { + assert(memHasBank(mem) && "Unexpected memory scope."); + static std::unordered_map map = { + {MemScope::L2, 3 * 1024 * 1024U}, + {MemScope::MM, 256 * 1024U}, + {MemScope::PAM, 128 * 1024U}, + {MemScope::FAM, 128 * 1024U}}; + auto it = map.find(mem); + return (it != map.end()) ? it->second : getPreferedAlignBytes(); +} + +static inline MemScope getMemScope(Type type) { + assert(isa(type) && "Unexpected type"); + MemRefType memType = cast(type); + return static_cast(memType.getMemorySpaceAsInt()); +} + +static inline MemScope getMemScope(Value value) { + return getMemScope(value.getType()); +} + +static inline MemScope getMemScope(Operation *op) { + if (isa(op)) { + MemRefType memType = cast(op).getType(); + return getMemScope(memType); + } else if (isa(op)) { + auto memSpace = cast(op).getMemorySpace(); + if (memSpace && isa(*memSpace)) + return static_cast(cast(*memSpace).getInt()); + return MemScope::UNKNOWN; + } else { + assert(false && "Unexpected Operation"); + } +} + +static inline void setMemScope(Value value, MemScope scope) { + assert(isa(value.getType()) && "Unexpected type"); + MemRefType memType = cast(value.getType()); + Type AttrType = IntegerType::get(memType.getContext(), 64); + Attribute mem_scope = IntegerAttr::get(AttrType, scope); + auto newType = MemRefType::Builder(memType).setMemorySpace(mem_scope); + value.setType(newType); +} + +static inline void setMemScope(Operation *op, MemScope scope) { + if (isa(op)) { + setMemScope(cast(op).getResult(), scope); + } else if (isa(op)) { + Type AttrType = IntegerType::get(op->getContext(), 64); + Attribute mem_scope = IntegerAttr::get(AttrType, scope); + cast(op).setMemorySpaceAttr(mem_scope); + } else { + assert(false && "Unexpected Operation"); + } +} + +static inline uint32_t getMemorySize(Type type) { + if (isa(type)) { + MemRefType memType = cast(type); + return affine::getIntOrFloatMemRefSizeInBytes(memType).value(); + } else if (isa(type)) { + TensorType tensorType = cast(type); + auto memType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + return affine::getIntOrFloatMemRefSizeInBytes(memType).value(); + } else { + assert(false && "Unexpected type"); + } +} + +static inline uint32_t getMemorySize(Operation *op) { + if (isa(op)) { + return getMemorySize(cast(op).getType()); + } else if (isa(op)) { + return getMemorySize(cast(op).getType()); + } else { + assert(false && "Unexpected Operation"); + } +} + +static inline Value getMemoryValue(Operation *op) { + if (isa(op)) { + return cast(op).getResult(); + } else if (isa(op)) { + return cast(op).getResult(); + } else { + assert(false && "Unexpected Operation"); + } +} + +static inline bool isSubkernelBufferOp(Operation *op) { + return false; + //return isa(op) && op->hasAttr(schedulePrimitive); +} + +static inline bool isMemoryAllocOp(Operation *op) { + return isa(op) || isa(op); +} + +static inline void +setMemoryPrior(Operation *op, int64_t priority = BufferPrior::PRIOR_MEDIUM) { + priority = std::max(BufferPrior::PRIOR_MIN, priority); + priority = std::min(BufferPrior::PRIOR_MAX, priority); + Type AttrType = IntegerType::get(op->getContext(), 64); + Attribute prior = IntegerAttr::get(AttrType, priority); + op->setAttr(priorityName, prior); +} + +static inline void +setBankAlonePrior(Operation *op, int64_t priority = BufferPrior::PRIOR_MEDIUM) { + priority = std::max(BufferPrior::PRIOR_MIN, priority); + setMemoryPrior(op, priority + BufferPrior::BANK_ALONE_THRESHOLD); +} + +static inline int64_t getMemoryPrior(Operation *op) { + if (auto prior = op->getAttrOfType(priorityName)) + return prior.getInt(); + return BufferPrior::PRIOR_MEDIUM; +} + +/// Return the func::FuncOp called by `callOp`. +static inline func::FuncOp getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + if (!sym) + return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + +static inline void setAddrAtIndex(Operation *op, unsigned idx, int64_t addr) { + Builder builder(op); + auto oldAddrAttr = op->getAttrOfType(addrName); + SmallVector newAddrs; + for (unsigned i = 0; i < oldAddrAttr.size(); ++i) { + if (i == idx) { + newAddrs.push_back(builder.getI64IntegerAttr(addr)); + } else { + newAddrs.push_back(oldAddrAttr[i]); + } + } + op->setAttr(addrName, builder.getArrayAttr(newAddrs)); +} + +} // namespace mlir::ev + +#endif // EV_SUPPORT_MEMORY_H_ diff --git a/third_party/evas/backend/include/evas/helper_hpe.h b/third_party/evas/backend/include/evas/helper_hpe.h new file mode 100644 index 0000000000..9a714a7796 --- /dev/null +++ b/third_party/evas/backend/include/evas/helper_hpe.h @@ -0,0 +1,269 @@ +/* Copyright (c) 2024, EVAS Technology Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of EVAS Technology Inc nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +/////////////////////////////////////////////////////////////////////////////// +// These are Helper functions for initialization and error checking + +#ifndef _COMMON_HELPER_H +#define _COMMON_HELPER_H + + +// #ifdef __AC_HOST_COMPILE__ +#include +#include +#include +#include +#include +#include + +union gem_float { + struct fp { + unsigned int mantissa : 23; + unsigned int exp : 8; + unsigned int sign : 1; + } B; + float A; +}; + +static inline int ReadBinFile(const char *fileName, void **inputBuff, + unsigned int *fileSize) { + FILE *binFile; + long binFileBufferLen; + + // 打开文件 + binFile = fopen(fileName, "rb"); + if (binFile == NULL) { + printf("open file %s failed\n", fileName); + return -1; + } + + // 移动文件指针到文件末尾,获取文件大小 + fseek(binFile, 0, SEEK_END); + binFileBufferLen = ftell(binFile); + if (binFileBufferLen == 0) { + printf("binfile is empty, filename is %s\n", fileName); + fclose(binFile); + return -1; + } + rewind(binFile); // 重置文件指针到文件开始位置 + + // 分配缓冲区 + *inputBuff = malloc((size_t)binFileBufferLen); + if (*inputBuff == NULL) { + printf("malloc device buffer failed. size is %lu\n", binFileBufferLen); + fclose(binFile); + return -1; + } + + // 读取文件内容到缓冲区 + fread(*inputBuff, sizeof(char), (size_t)binFileBufferLen, binFile); + fclose(binFile); // 关闭文件 + + if (fileSize != NULL) { + *fileSize = (unsigned int)binFileBufferLen; + } + return 0; +} + +// Runtime error messages +template +void check(T result, char const *const func, const char *const file, + int const line) { + if (result) { + fprintf(stderr, "Runtime error at %s:%d code=%d(%s) \"%s\" \n", file, + line, static_cast(result), evGetErrorName(result), + func); + exit(EXIT_FAILURE); + } +} + +// This will output the proper runtime error strings in the event +// that a host call returns an error +#define checkErrors(val) check((val), #val, __FILE__, __LINE__) + +#define checkEqual(val, cmp) \ + if ((val) != (cmp)) { \ + fprintf(stderr, "Runtime error at %s:%d code=%d \n", __FILE__, \ + __LINE__, static_cast(val)); \ + exit(EXIT_FAILURE); \ + } +#define checkNotEqual(val, cmp) \ + if ((val) == (cmp)) { \ + fprintf(stderr, "Runtime error at %s:%d code=%d \n", __FILE__, \ + __LINE__, static_cast(val)); \ + exit(EXIT_FAILURE); \ + } + +// This will output the proper error string when calling evGetLastError +#define getLastError(msg) __getLastError(msg, __FILE__, __LINE__) + +inline void __getLastError(const char *errorMessage, const char *file, + const int line) { + evError_t err = evGetLastError(); + + if (evSuccess != err) { + fprintf(stderr, + "%s(%i) : getLastError() error :" + " %s : (%d) %s.\n", + file, line, errorMessage, static_cast(err), + evGetErrorString(err)); + exit(EXIT_FAILURE); + } +} + +// This will only print the proper error string when calling evGetLastError +// but not exit program incase error detected. +#define printLastError(msg) __printLastError(msg, __FILE__, __LINE__) + +inline void __printLastError(const char *errorMessage, const char *file, + const int line) { + evError_t err = evGetLastError(); + + if (evSuccess != err) { + fprintf(stderr, + "%s(%i) : getLastError() error :" + " %s : (%d) %s.\n", + file, line, errorMessage, static_cast(err), + evGetErrorString(err)); + } +} + +#ifndef MAX +#define MAX(a, b) (a > b ? a : b) +#endif + +#ifndef MIN +#define MIN(a, b) (a < b ? a : b) +#endif + +// Float To Int conversion +inline int ftoi(float value) { + return (value >= 0 ? static_cast(value + 0.5) + : static_cast(value - 0.5)); +} + +// General NPU Device Initialization +inline int npuDeviceInit(int devID) { + int device_count; + checkErrors(evGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, "npuDeviceInit() error: " + "no devices supporting runtime.\n"); + exit(EXIT_FAILURE); + } + + if (devID < 0) { + devID = 0; + } + + if (devID > device_count - 1) { + fprintf(stderr, "\n"); + fprintf(stderr, ">> %d capable NPU device(s) detected. <<\n", + device_count); + fprintf(stderr, + ">> npuDeviceInit (-device=%d) is not a valid" + " NPU device. <<\n", + devID); + fprintf(stderr, "\n"); + return -devID; + } + + int major = -1; + checkErrors(evGetDeviceAttribute(&major, evDevAttrChipHWVersion, devID)); + + if (major < 0) { + fprintf( + stderr, + "npuDeviceInit(): NPU device [%d] does not support runtime. %d\n", + devID, major); + exit(EXIT_FAILURE); + } + + checkErrors(evSetDevice(devID)); + printf("npuDeviceInit() Device [%d]: %d\n", devID, major); + + return devID; +} + +#define ELAPSED_TIME_START() \ + struct timespec start; \ + do { \ + clock_gettime(CLOCK_MONOTONIC, &start); \ + } while (0) + +#define ELAPSED_TIME_END() \ + do { \ + struct timespec end; \ + long elapsed_time; \ + clock_gettime(CLOCK_MONOTONIC, &end); \ + elapsed_time = \ + (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; \ + printf("%s elapsed time: %ld seconds\n", __FILE__, elapsed_time); \ + } while (0) +// #endif +#define EXTERN(var, dtype, nums) extern dtype var[nums] + +#define INCBIN(sym, file, section, align) \ + asm(".section " #section ", \"a\", @progbits\n\t" \ + ".type " #sym ", @object\n\t" \ + ".global " #sym "\n\t" #sym ":\n\t" \ + ".align " #align "\n\t" \ + ".incbin " #file "\n\t" \ + ".align " #align "\n\t" #sym "_end:\n\t"); + +typedef struct { + int cls_nr; + int core_nr; +} kernel_configs; + +static bool init = false; + +static inline int get_rand_clusters(void) { + if (!init) { + srand(time(NULL)); + init = true; + } + return rand() % 8 + 1; +} + +static inline int get_rand_cores(void) { + if (!init) { + srand(time(NULL)); + init = true; + } + return rand() % 4 + 1; +} +#ifdef __AC_DEVICE_COMPILE__ +#define MFN \ + __attribute__(( \ + target("no-zve32x,no-zve32f,no-zve64x,no-zve64f,no-zvfh,no-zvl32b,no-" \ + "zvl64b,no-zvl128b,no-zvl256b,no-zvl512b,no-zvl1024b"))) +#else +#define MFN +#endif + +#endif // COMMON_HELPER_H_ diff --git a/third_party/evas/backend/name.conf b/third_party/evas/backend/name.conf new file mode 100644 index 0000000000..ca7aa48e2b --- /dev/null +++ b/third_party/evas/backend/name.conf @@ -0,0 +1 @@ +evas diff --git a/third_party/evas/backend/triton_evas.cc b/third_party/evas/backend/triton_evas.cc new file mode 100644 index 0000000000..fd9e324747 --- /dev/null +++ b/third_party/evas/backend/triton_evas.cc @@ -0,0 +1,10 @@ +#include + +namespace py = pybind11; + +void init_triton_evas(py::module &&m) { + m.doc() = "EVAS backend bindings for Triton"; + + m.def("is_evas_available", []() { return true; }); + m.def("load_dialects", [](py::object) {}); +} diff --git a/third_party/evas/backend/utils.py b/third_party/evas/backend/utils.py new file mode 100644 index 0000000000..00abbede0c --- /dev/null +++ b/third_party/evas/backend/utils.py @@ -0,0 +1,67 @@ +import os, subprocess, logging, sys + +from pathlib import Path +from functools import wraps +import hashlib + +def _show_perf(): + return os.environ.get("PERF_LOG_PRINT", None) != None + +def _cache_key(s): + return hashlib.sha256(s.encode("utf-8")).hexdigest() + +def run_command(command): + try: + # Execute the command + subprocess.check_call(command) + print(f"Command '{' '.join(command)}' executed successfully.") + except subprocess.CalledProcessError as e: + logging.error( + "An error occurred while executing the command:" + " ".join(command) + ) + logging.error(e.stderr) # Print standard error + sys.exit(1) + +def compile_to_linalg(jit_func, *args, **kwargs): + kwargs["debug"] = None + kwargs["to_ttsharedir"] = True + return jit_func.warmup(*args, grid=[1,], **kwargs) + + +def with_env_vars(**env_vars): + """Decorator to run a function with specific environment variables.""" + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Save original values + original_values = {} + for key, value in env_vars.items(): + original_values[key] = os.environ.get(key) + os.environ[key] = value + + try: + # Run the function with modified environment + return func(*args, **kwargs) + finally: + # Restore original values + for key, value in original_values.items(): + if value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = value + return wrapper + return decorator + +def ttsharedir_compare(jit_func, *args, **kargs): + ttshared_kernel = compile_to_linalg(jit_func, *args, **kargs) + golden_ir_dir = os.environ.get("EVAS_GOLDEN_IR_DIR") + if golden_ir_dir is None: + raise EnvironmentError("Please set EVAS_GOLDEN_IR_DIR.") + golden_ir_file = Path(os.path.join(golden_ir_dir, ttshared_kernel.name + ".mlir")) + if golden_ir_file.exists() and golden_ir_file.is_file(): + if golden_ir_file.read_bytes() == ttshared_kernel.kernel: + print(f"Compiled IR compared with {golden_ir_file} successfully.") + return True + raise AssertionError(f"compiled IR compared with {golden_ir_file} failed") + raise LookupError(f"{golden_ir_file} is not found") diff --git a/third_party/evas/bin/evas-triton-opt/CMakeLists.txt b/third_party/evas/bin/evas-triton-opt/CMakeLists.txt new file mode 100644 index 0000000000..b01f330d13 --- /dev/null +++ b/third_party/evas/bin/evas-triton-opt/CMakeLists.txt @@ -0,0 +1,30 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(evas-triton-opt evas-triton-opt.cpp PARTIAL_SOURCES_INTENDED) + +llvm_update_compile_flags(evas-triton-opt) +target_include_directories(evas-triton-opt PRIVATE + ${TRITON_SHARED_TOOLS_DIR} +) +target_link_libraries(evas-triton-opt PRIVATE + EvasTritonToEvas + EVLinalgTransforms + TritonTransforms + TritonSharedAnalysis + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + MLIROptLib + MLIRPass + MLIRRegisterAllPasses + MLIRRegisterAllExtensions + MLIRMemRefTransforms + MLIRSparseTensorTransforms + MLIRControlFlowTransforms + MLIRTensorInferTypeOpInterfaceImpl + MLIRTransforms +) + +mlir_check_all_link_libraries(evas-triton-opt) diff --git a/third_party/evas/bin/evas-triton-opt/evas-triton-opt.cpp b/third_party/evas/bin/evas-triton-opt/evas-triton-opt.cpp new file mode 100644 index 0000000000..db426fbe5b --- /dev/null +++ b/third_party/evas/bin/evas-triton-opt/evas-triton-opt.cpp @@ -0,0 +1,43 @@ +#include "RegisterTritonSharedDialects.h" +#include "evas/Conversion/TritonToEvas/TritonToEvasPipeline.h" +#include "evas/Dialect/Linalg/IR/LinalgOpsExt.h" +#include "evas/Transform/Linalg/Passes.h" + +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonSharedDialects(registry); + mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); + mlir::vector::registerBufferizableOpInterfaceExternalModels(registry); + mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry); + mlir::linalg::registerEvasLinalgOps(registry); + mlir::cf::registerBufferizableOpInterfaceExternalModels(registry); + mlir::ttx::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry); + mlir::registerAllExtensions(registry); + mlir::memref::registerAllocationOpInterfaceExternalModels(registry); + + mlir::triton::evas::registerTritonToEvasPipeline(); + + return mlir::asMainReturnCode( + mlir::MlirOptMain(argc, argv, "EVAS Triton lowering driver\n", registry)); +} diff --git a/third_party/evas/include/evas/Conversion/TritonToEvas/TritonToEvasPipeline.h b/third_party/evas/include/evas/Conversion/TritonToEvas/TritonToEvasPipeline.h new file mode 100644 index 0000000000..484d5fe931 --- /dev/null +++ b/third_party/evas/include/evas/Conversion/TritonToEvas/TritonToEvasPipeline.h @@ -0,0 +1,23 @@ +#ifndef EVAS_CONVERSION_TRITONTOEVAS_TRITONTOEVASPIPELINE_H +#define EVAS_CONVERSION_TRITONTOEVAS_TRITONTOEVASPIPELINE_H + +#include "mlir/Pass/PassManager.h" + +namespace mlir { +class ModuleOp; +template +class OperationPass; +namespace triton { +namespace evas { + +void buildTritonToEvasPipeline(OpPassManager &pm); +void registerTritonToEvasPipeline(); +std::unique_ptr> +createEvasTritonArithToLinalgPass(bool tensorPtrToLinalg = true, + bool transposeReduceToRank0 = true); + +} // namespace evas +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/evas/include/evas/Dialect/Linalg/IR/CMakeLists.txt b/third_party/evas/include/evas/Dialect/Linalg/IR/CMakeLists.txt new file mode 100644 index 0000000000..41eea5d561 --- /dev/null +++ b/third_party/evas/include/evas/Dialect/Linalg/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS LinalgEnumsExt.td) +mlir_tablegen(LinalgOpsExtEnums.h.inc -gen-enum-decls) +mlir_tablegen(LinalgOpsExtEnums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS LinalgOpsExt.td) +mlir_tablegen(LinalgOpsExt.h.inc -gen-op-decls) +mlir_tablegen(LinalgOpsExt.cpp.inc -gen-op-defs) +add_public_tablegen_target(EvasLinalgOpsExtIncGen) diff --git a/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgEnumsExt.td b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgEnumsExt.td new file mode 100644 index 0000000000..9d2c02001b --- /dev/null +++ b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgEnumsExt.td @@ -0,0 +1,21 @@ +//===- LinalgEnumsExt.td - EVAS Linalg enum attrs ---------*- tablegen -*-===// + +#ifndef EVAS_LINALGEXT_ENUMS +#define EVAS_LINALGEXT_ENUMS + +include "mlir/IR/EnumAttr.td" + +def EVAS_LINALG_EvaMindCastRndModeAttr : I64EnumAttr< + "EvaMindCastRndMode", "", + [ + I64EnumAttrCase<"RNE", 0>, + I64EnumAttrCase<"RTZ", 1>, + I64EnumAttrCase<"RDN", 2>, + I64EnumAttrCase<"RUP", 3>, + I64EnumAttrCase<"RMM", 4>, + ]> { + let cppNamespace = "::mlir::linalg"; +} + +#endif // EVAS_LINALGEXT_ENUMS + diff --git a/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.h b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.h new file mode 100644 index 0000000000..e3b32c0df5 --- /dev/null +++ b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.h @@ -0,0 +1,19 @@ +#ifndef EVAS_DIALECT_LINALG_IR_LINALGOPSEXT_H +#define EVAS_DIALECT_LINALG_IR_LINALGOPSEXT_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/IR/DialectRegistry.h" + +#include "evas/Dialect/Linalg/IR/LinalgOpsExtEnums.h.inc" + +#define GET_OP_CLASSES +#include "evas/Dialect/Linalg/IR/LinalgOpsExt.h.inc" + +namespace mlir::linalg { + +void registerEvasLinalgOps(DialectRegistry ®istry); + +} // namespace mlir::linalg + +#endif // EVAS_DIALECT_LINALG_IR_LINALGOPSEXT_H diff --git a/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.td b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.td new file mode 100644 index 0000000000..f3ecc82672 --- /dev/null +++ b/third_party/evas/include/evas/Dialect/Linalg/IR/LinalgOpsExt.td @@ -0,0 +1,95 @@ +//===- LinalgOpsExt.td - EVAS Linalg dialect ops ----------*- tablegen -*-===// + +#ifndef EVAS_LINALGEXT_OPS +#define EVAS_LINALGEXT_OPS + +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "evas/Dialect/Linalg/IR/LinalgEnumsExt.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" + +class Evas_LinalgOp traits = []> : + Op; + +def Linalg_CastOp : Evas_LinalgOp<"cast", + [DestinationStyleOpInterface, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Cast operator"; + let description = [{ + Casts elements from the input shaped value into the output shaped value. + }]; + + let arguments = (ins AnyShaped:$input, + AnyShaped:$output, + DefaultValuedAttr:$rounding_mode); + + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `{` `rounding_mode` `:` $rounding_mode `}` + attr-dict + (`->` type($result)^)? + }]; + + let hasVerifier = 1; + let hasFolder = 1; + + let extraClassDeclaration = [{ + ::mlir::Operation::operand_range getInputs() { return getODSOperands(0); } + ::mlir::Operation::operand_range getOutputs() { return getODSOperands(1); } + + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, ValueRange inputs, + ValueRange outputs, + ArrayRef attributes = {}) { + assert(inputs.size() + outputs.size() == 2u && + "mismatched number of parameters"); + odsState.addOperands(inputs[0]); + odsState.addOperands(outputs[0]); + odsState.addAttributes(attributes); + } + + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef attributes = {}) { + assert(inputs.size() + outputs.size() == 2u && + "mismatched number of parameters"); + odsState.addOperands(inputs); + odsState.addOperands(outputs); + odsState.addAttributes(attributes); + odsState.addTypes(resultTensorTypes); + } + + ShapedType getInputOperandType() { + return cast(getInput().getType()); + } + ShapedType getOutputOperandType() { + return cast(getOutput().getType()); + } + int64_t getInputOperandRank() { return getInputOperandType().getRank(); } + int64_t getOutputOperandRank() { return getOutputOperandType().getRank(); } + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; +} + +#endif // EVAS_LINALGEXT_OPS diff --git a/third_party/evas/include/evas/Transform/Linalg/CMakeLists.txt b/third_party/evas/include/evas/Transform/Linalg/CMakeLists.txt new file mode 100644 index 0000000000..a9bb2f0ed9 --- /dev/null +++ b/third_party/evas/include/evas/Transform/Linalg/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name EVLinalg) +add_public_tablegen_target(EVLinalgTransformsIncGen) \ No newline at end of file diff --git a/third_party/evas/include/evas/Transform/Linalg/MemoryAlloc.h b/third_party/evas/include/evas/Transform/Linalg/MemoryAlloc.h new file mode 100644 index 0000000000..74e2d5ae4e --- /dev/null +++ b/third_party/evas/include/evas/Transform/Linalg/MemoryAlloc.h @@ -0,0 +1,111 @@ +//===------------------------- MemoryAlloc.h --------------------*- C++ -*-===// +// +// Copyright 2024 EVAS Intelligence Co.,Ltd. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef EV_TRANSFORMS_MEMORYALLOC_H +#define EV_TRANSFORMS_MEMORYALLOC_H + +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "evas/Transform/Linalg/RegionMemAllocator.h" +#include + +namespace mlir { +class FunctionOpInterface; +} + +namespace mlir::triton::ev { + +typedef enum { + SIZE_PRIOR = 0, + LIVE_RANGE_PRIOR = 1, +} AssignPrior; + +typedef struct { + size_t alignment; + AssignPrior assignPrior; +} AllocPolicy; + +//===----------------------------------------------------------------------===// +// BFS Region Visitor +//===----------------------------------------------------------------------===// + +class BFSRegionVisitor { +private: + bool update = false; + std::queue toVisitRegion; + +public: + BFSRegionVisitor(bool update) : update(update) {} + void visit(Region *rootRegion, const Liveness &LN, + const std::shared_ptr MA); + +private: + void pushSubRegionAndUpdate(Region *visitedRegion, + const std::shared_ptr MA); +}; + +//===----------------------------------------------------------------------===// +// Memory Allocation Implement +//===----------------------------------------------------------------------===// + +class MemoryAllocImpl { +private: + size_t alignment = 128; + bool preview = false; + bool bankopt = false; + CompareBufferT ToAssignOrder; + BFSRegionVisitor BfsRV; + +public: + MemoryAllocImpl(AllocPolicy policy, bool preview, bool update, bool bankopt) + : alignment(policy.alignment), preview(preview), bankopt(bankopt), + BfsRV(update) { + // TODO:: Support more greedy allocation policy + assert(policy.assignPrior == SIZE_PRIOR && "UnSupported policy."); + ToAssignOrder = [](const std::shared_ptr lhs, + const std::shared_ptr rhs) { + if (lhs->getPriority() != rhs->getPriority()) + return lhs->getPriority() > rhs->getPriority(); + if (lhs->size() != rhs->size()) + return lhs->size() > rhs->size(); + return lhs->getSlotIndex() < rhs->getSlotIndex(); + }; + } + + void runOnFuncAtScope(MemScope memScope, FunctionOpInterface func, + const Liveness &LN); + void runOnFunction(FunctionOpInterface func, const Liveness &LN); +}; + +} // namespace mlir::triton::ev + +namespace mlir::triton { + +std::unique_ptr> createMemoryAllocPass(); + +std::unique_ptr> +createMemoryAllocPass(size_t memScope, size_t alignment, bool preview, + bool update, bool bankopt); + +} // namespace mlir::triton + +#endif // EVOFC_TRANSFORMS_MEMORYALLOC_H diff --git a/third_party/evas/include/evas/Transform/Linalg/Passes.h b/third_party/evas/include/evas/Transform/Linalg/Passes.h new file mode 100644 index 0000000000..5fd1d3342b --- /dev/null +++ b/third_party/evas/include/evas/Transform/Linalg/Passes.h @@ -0,0 +1,33 @@ +#ifndef EV_LINALG_TRANSFORMS_PASSES_H_ +#define EV_LINALG_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace ev { + +std::unique_ptr createRemoveLoopIterArgsWithMemrefTypePass(); +std::unique_ptr createRemoveScalarPass(); +std::unique_ptr createRewriteFuncOpArgsTypePass(); +std::unique_ptr createSplitComputationalOpPass(); +std::unique_ptr createInsertDeallocOpPass(); +std::unique_ptr createSetDeviceInfoPass(); +std::unique_ptr createMaterializeAnnotationPass(); +std::unique_ptr createDoubleBufferPass(); +std::unique_ptr createBufferizePass(); +std::unique_ptr createMemoryAllocPass(); +std::unique_ptr createMemoryAllocPass(size_t memScope, size_t alignment, + bool preview, bool bankopt); +std::unique_ptr createEncapsulateLinalgOpPass(); +std::unique_ptr createSetMemRefScopePass(); +std::unique_ptr createMemoryPromotionPass(); +std::unique_ptr createRemoveRedundencyCopyPass(); +std::unique_ptr createRewriteDataTypePass(); +} // namespace ev +#define GEN_PASS_REGISTRATION +#include "evas/Transform/Linalg/Passes.h.inc" +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/evas/include/evas/Transform/Linalg/Passes.td b/third_party/evas/include/evas/Transform/Linalg/Passes.td new file mode 100644 index 0000000000..56624bbf93 --- /dev/null +++ b/third_party/evas/include/evas/Transform/Linalg/Passes.td @@ -0,0 +1,171 @@ +#ifndef EV_LINALG_PASSES +#define EV_LINALG_PASSES + +include "mlir/Pass/PassBase.td" + +def SplitComputationalOp: Pass<"ev-split-computational-op", "mlir::ModuleOp"> { + let summary = "split computational operations as functions"; + let description = [{ + split computational operations as functions to distinct with memory load/store related operations. + }]; + let constructor = "mlir::triton::ev::createSplitComputationalOpPass()"; + let dependentDialects = ["mlir::tensor::TensorDialect", "mlir::linalg::LinalgDialect"]; +} + +def InsertDeallocOp: Pass<"ev-insert-dealloc-op", "mlir::ModuleOp"> { + let summary = "insert memref.dealloc for memref.alloc based on the liveness analysis"; + let description = [{ + insert memref.dealloc for memref.alloc based on the liveness analysis. + }]; + let constructor = "mlir::triton::ev::createInsertDeallocOpPass()"; + let dependentDialects = ["mlir::memref::MemRefDialect"]; +} + +def SetDeviceInfo: Pass<"ev-set-device-info", "mlir::ModuleOp"> { + let summary = "set device info like memscope and corebind to subkernel"; + let description = [{ + set device info like memscope and corebind to subkernel. + }]; + let constructor = "mlir::triton::ev::createSetDeviceInfoPass()"; + let dependentDialects = ["mlir::memref::MemRefDialect"]; +} + +def MemoryAlloc : Pass<"ev-memory-alloc", "mlir::ModuleOp"> { + let summary = "Pass to perform memory allocation for the module"; + let constructor = "mlir::triton::ev::createMemoryAllocPass()"; + let dependentDialects = [ + "mlir::memref::MemRefDialect", + "mlir::func::FuncDialect", + "mlir::bufferization::BufferizationDialect" + ]; + + let options = [ + Option<"memScope", "mem-scope", "size_t", /*default=*/"0", + "Memory scope considered to allocate. 0 indicates all spaces">, + Option<"alignment", "alignment", "size_t", /*default=*/"128", + "Memory address alignment, must be power of 2">, + Option<"preview", "preview", "bool", /*default=*/"false", + "Whether to preview allocation result.">, + Option<"bankopt", "bankopt", "bool", /*default=*/"false", + "Whether to enable bank alone optimization.">, + ]; +} + +def DoubleBuffer : Pass<"ev-double-buffer", "mlir::ModuleOp"> { + let summary = "double buffer optimization"; + let description = [{ + perform double buffer on the specific value in for loops. + }]; + let constructor = "mlir::triton::ev::createDoubleBufferPass()"; + let dependentDialects = ["mlir::tensor::TensorDialect", "mlir::linalg::LinalgDialect"]; +} + +def MaterializeAnnotation : Pass<"ev-materialize-annotation", "mlir::ModuleOp"> { + let summary = "materialize annotation to device info"; + let description = [{ + materialize annotation to device info like memscope and corebind. + }]; + let constructor = "mlir::triton::ev::createMaterializeAnnotationPass()"; + let dependentDialects = ["mlir::memref::MemRefDialect"]; +} + +def RewriteFuncOpArgsType : Pass<"ev-rewrite-func-op-args-type", "mlir::ModuleOp"> { + let summary = "rewrite func op args type to memref"; + let description = [{ + rewrite func op args type to memref. + }]; + let constructor = "mlir::triton::ev::createRewriteFuncOpArgsTypePass()"; +} + +def RemoveLoopIterArgsWithMemrefType : Pass<"ev-linalg-inplace-optimize", "mlir::ModuleOp"> { + let summary = "inplace optimize"; + let description = [{ + inplace optimize. + }]; + let constructor = "mlir::triton::ev::createRemoveLoopIterArgsWithMemrefTypePass()"; +} + +def RemoveScalar : Pass<"remove-scalar", "mlir::ModuleOp"> { + let summary = "Convert the scalar in the ir"; + let constructor = "mlir::triton::ev::createRemoveScalarPass()"; + let description = [{ + Scalar reading is prone to cache consistency issues, so it is necessary to reduce or avoid the occurrence of scalars. + }]; +} + +def Bufferize : Pass<"ev-bufferize", "mlir::ModuleOp"> { + let summary = "bufferize pass adapted for ev backend"; + let description = [{ + bufferize pass adapted for ev backend + }]; + let constructor = "mlir::triton::ev::createBufferizePass()"; + let dependentDialects = ["mlir::memref::MemRefDialect"]; +} + +def EncapsulateLinalgOp : Pass<"ev-encapsulate-linalg-op", "mlir::ModuleOp"> { + let summary = "Encapsulate Linalg operations into sub-functions"; + let description = [{ + This pass extracts Linalg operations (such as linalg.generic, linalg.matmul, etc.) + into separate sub-functions (FuncOp) and inserts a call (CallOp) at the original location, + enabling further optimization and analysis. + }]; + let constructor = "mlir::triton::ev::createEncapsulateLinalgOpPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::linalg::LinalgDialect" + ]; +} + +def SetMemRefScope : Pass<"ev-set-memref-scope", "mlir::ModuleOp"> { + let summary = "Set memref scope attributes using type converter in dialect conversion"; + let description = [{ + This pass utilizes a type converter in dialect conversion to set the scope attribute + of all memref types following these rules: + 1. If the memref type is scalar, set to ev::MemScope::DDR + 2. If memref.alloc op is met, set the result type to ev::MemScope::MM + 3. Set function signature input arguments memref type to DDR + }]; + let constructor = "mlir::triton::ev::createSetMemRefScopePass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::memref::MemRefDialect" + ]; +} + +def MemoryPromotion : Pass<"ev-memory-promotion", "mlir::ModuleOp"> { + let summary = "Promote memory scope of linalg.matmul operations from MM to AM(5) and add memref.copy back to MM"; + let description = [{ + This pass promote AM/MM and DDR/L2 pair. + }]; + let constructor = "mlir::triton::ev::createMemoryPromotionPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::linalg::LinalgDialect", + "mlir::memref::MemRefDialect" + ]; +} + +def RemoveRedundencyCopy : Pass<"ev-remove-redundency-copy", "mlir::ModuleOp"> { + let summary = "Optimize redundant memref.copy operations"; + let description = [{ + This pass eliminates redundant memref.copy operations through several optimizations: + 1. Eliminate self-copy: copy(A, A) -> no-op + 2. Chain copy elimination: copy(A, B) + copy(B, C) -> copy(A, C) + 3. Eliminate redundant copy: remove duplicate copy(A, B) operations + }]; + let constructor = "mlir::triton::ev::createRemoveRedundencyCopyPass()"; + let dependentDialects = [ + "mlir::memref::MemRefDialect", + "mlir::func::FuncDialect" + ]; +} + +def RewriteDataType : Pass<"ev-rewrite-data-type", "mlir::ModuleOp"> { + let summary = "Rewrite illegal data types using pattern rewriting"; + let description = [{ + Rewrite illegal data types like i1 to i8 and i64 to i32 using pattern rewriting + }]; + let constructor = "mlir::triton::ev::createRewriteDataTypePass()"; + let dependentDialects = ["mlir::arith::ArithDialect", "mlir::func::FuncDialect"]; +} +#endif diff --git a/third_party/evas/include/evas/Transform/Linalg/RegionMemAllocator.h b/third_party/evas/include/evas/Transform/Linalg/RegionMemAllocator.h new file mode 100644 index 0000000000..a4afd72a38 --- /dev/null +++ b/third_party/evas/include/evas/Transform/Linalg/RegionMemAllocator.h @@ -0,0 +1,268 @@ +//===-------------------- RegionMemAllocator.h ------------------*- C++ -*-===// +// +// Copyright 2024 EVAS Intelligence Co.,Ltd. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef EV_TRANSFORMS_REGIONMEMALLOCATOR_H +#define EV_TRANSFORMS_REGIONMEMALLOCATOR_H + +#include "epu/memory.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" + +namespace mlir::triton::ev { +using namespace mlir::ev; +//===----------------------------------------------------------------------===// +// Living Buffer Class +//===----------------------------------------------------------------------===// + +class LiveBuffer { +public: + /// Create Normal Buffers + LiveBuffer(Operation *op) : op(op) { + memScope = getMemScope(op); + bufferSize = getMemorySize(op); + } + LiveBuffer(Value bufferValue, MemScope memScope, int64_t phyAddr, + uint64_t slotIndex, const std::vector &liveVector) + : bufferValue(bufferValue), memScope(memScope), phyAddr(phyAddr), + slotIndex(slotIndex) { + bufferSize = getMemorySize(bufferValue.getType()); + op = bufferValue.getDefiningOp(); + assert(op != nullptr && + "bufferValue must be a defining value of an operation"); + priority = getMemoryPrior(op); + setLiveInterval(liveVector); + } + /// Create Sub Region Buffers + LiveBuffer(Operation *op, MemScope memScope, int64_t phyAddr, int64_t size) + : op(op), memScope(memScope), phyAddr(phyAddr), bufferSize(size) {} + /// Create External Buffers + LiveBuffer(MemScope memScope, int64_t phyAddr, int64_t size) + : memScope(memScope), phyAddr(phyAddr), bufferSize(size) {} + + void setPhyAddr(int64_t phyAddr) { this->phyAddr = phyAddr; } + void setPriority(int64_t priority) { this->priority = priority; } + void setSlotIndex(uint64_t slotIndex) { + this->slotIndex = static_cast(slotIndex); + } + void setLiveInterval(Operation *liveOp) { + liveInterval.clear(); + liveInterval.insert(liveOp); + } + void setLiveInterval(const std::vector &liveVector) { + liveInterval = std::set(liveVector.begin(), liveVector.end()); + } + + // get operation order index in current region. + int64_t getSlotIndex() { return slotIndex; } + int64_t addr() { return phyAddr; } + int64_t size() { return bufferSize; } + MemScope scope() { return memScope; } + Operation *getOperation() { return op; } + Value value() { return bufferValue ? bufferValue : getMemoryValue(op); } + int64_t getPriority() { return priority; } + int64_t upperBound() { return addr() + size(); } + std::set &getLiveInterval() { return liveInterval; } + int64_t fixedAddr() { + if (memScope == MemScope::DDR) { + return ev::ddrAddr(addr()); + } + return addr(); + } + bool isPreferBankAlone() { + return memHasBank(memScope) && + priority >= BufferPrior::BANK_ALONE_THRESHOLD; + } + bool isOverflow() { + return upperBound() > memCapacity(memScope) || addr() < 0; + } + bool isExternal() { return op == nullptr; } + bool isSubRegionBuf() { + return op != nullptr && !isMemoryAllocOp(op) && !isSubkernelBufferOp(op); + } + bool isConflictWith(std::set &checkInterval) { + for (auto &checkOp : checkInterval) { + if (liveInterval.count(checkOp)) + return true; + } + return false; + } + bool isLiveAt(Operation *checkOp) { + if (liveInterval.count(checkOp)) + return true; + return false; + } + +private: + int64_t slotIndex = -1; + Operation *op = nullptr; + MemScope memScope = UNKNOWN; + int64_t phyAddr = 0; + int64_t bufferSize = 0; + std::set liveInterval; + int64_t priority = BufferPrior::PRIOR_MEDIUM; + Value bufferValue; +}; + +using CompareBufferT = std::function, + const std::shared_ptr)>; +using LiveBufferSet = std::set, CompareBufferT>; + +//===----------------------------------------------------------------------===// +// Buffer Info +//===----------------------------------------------------------------------===// + +struct BufferInfo { + MemScope scope; + int64_t address; + + BufferInfo(MemScope scope = MemScope::UNKNOWN, int64_t address = -1) + : scope(scope), address(address) {} +}; + +//===----------------------------------------------------------------------===// +// Basic Memory Allocator +//===----------------------------------------------------------------------===// + +class MemAllocator { +private: + bool preview; + MemScope memScope; + int64_t alignment; + LiveBufferSet virtualBufs; + LiveBufferSet phyBufsOrder; + LiveBufferSet phyBufsReverse; + +public: + static bool LayoutOrder(const std::shared_ptr lhs, + const std::shared_ptr rhs) { + if (lhs->addr() != rhs->addr()) + return lhs->addr() < rhs->addr(); + assert(!lhs->isExternal() && !rhs->isExternal() && "buffer overlap"); + return lhs->getSlotIndex() > rhs->getSlotIndex(); + } + + static bool LayoutReverse(const std::shared_ptr lhs, + const std::shared_ptr rhs) { + if (lhs->upperBound() != rhs->upperBound()) + return lhs->upperBound() > rhs->upperBound(); + assert(!lhs->isExternal() && !rhs->isExternal() && "buffer overlap"); + return lhs->getSlotIndex() > rhs->getSlotIndex(); + } + + MemAllocator(CompareBufferT ToAssignOrder, MemScope memScope, + int64_t alignment, bool preview) + : preview(preview), memScope(memScope), alignment(alignment), + virtualBufs(ToAssignOrder), phyBufsOrder(LayoutOrder), + phyBufsReverse(LayoutReverse) { + virtualBufs.clear(); + phyBufsOrder.clear(); + phyBufsReverse.clear(); + } + + virtual ~MemAllocator() {} + +private: + void initVirtualBufs(const std::vector &opsToAssign, + const Liveness &LN, uint64_t &nextSlotIndex); + + void initExternelBufs(const ArrayAttr &externalLives, + uint64_t &nextSlotIndex); + // Warning: There is no guarantee for the validity of fixed physical buffers. + void initCurrRegionPhyBufs(const std::vector &opsWithFixedAddr, + const Liveness &LN, uint64_t &nextSlotIndex); + void initSubRegionPhyBufs(const std::vector &opsHasSubRegion, + uint64_t &nextSlotIndex); + void visitOpsInRegion(Region *visitedRegion, + std::vector &opsToAssign, + std::vector &opsWithFixedAddr, + std::vector &opsHasSubRegion, + llvm::DenseMap &subkernelBuffers); + void initRegionBuffers(Region *region, const Liveness &LN, + uint64_t &nextSlotIndex); + +public: + std::shared_ptr pickNextBuffer(); + void insertPhyBuf(const std::shared_ptr buffer); + void erasePhyBuf(const std::shared_ptr buffer); + void revertPhyBuf(const std::shared_ptr buffer); + void assignAddrOrder(const std::shared_ptr visitedBuffer, + int64_t align); + void assignAddrReverse(const std::shared_ptr visitedBuffer, + int64_t align); + virtual void init(Region *region, const Liveness &LN); + virtual void rewrite(); + virtual void reset(); + virtual void allocate() = 0; + + bool isVirtBufEmpty() { return virtualBufs.empty(); } + LiveBufferSet &getAllocResult() { + assert(isVirtBufEmpty() && "Not finish allocate"); + return phyBufsOrder; + } + bool isPreview() { return preview; } + int64_t getAlign() { return alignment; } + MemScope getScope() { return memScope; } +}; + +//===----------------------------------------------------------------------===// +// Order Memory Allocator +//===----------------------------------------------------------------------===// + +class OrderMemAllocator : public MemAllocator { + +public: + OrderMemAllocator(CompareBufferT ToAssignOrder, MemScope memScope, + int64_t alignment, bool preview) + : MemAllocator(ToAssignOrder, memScope, alignment, preview) {} + + void allocate() override; +}; + +//===----------------------------------------------------------------------===// +// Dual-Directional Memory Allocator +//===----------------------------------------------------------------------===// + +class DualMemAllocator : public MemAllocator { +public: + DualMemAllocator(CompareBufferT ToAssignOrder, MemScope memScope, + int64_t alignment, bool preview) + : MemAllocator(ToAssignOrder, memScope, alignment, preview) {} + + void allocate() override; +}; + +//===----------------------------------------------------------------------===// +// Bank Optimization Memory Allocator +//===----------------------------------------------------------------------===// + +class BankOptAllocator : public MemAllocator { +public: + BankOptAllocator(CompareBufferT ToAssignOrder, MemScope memScope, + int64_t alignment, bool preview) + : MemAllocator(ToAssignOrder, memScope, alignment, preview) {} + + void allocate() override; + void rewrite() override; +}; + +} // namespace mlir::triton::ev + +#endif // EVOFC_TRANSFORMS_REGIONMEMALLOCATOR_H diff --git a/third_party/evas/lib/Conversion/TritonToEvas/CMakeLists.txt b/third_party/evas/lib/Conversion/TritonToEvas/CMakeLists.txt new file mode 100644 index 0000000000..5523f03789 --- /dev/null +++ b/third_party/evas/lib/Conversion/TritonToEvas/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(EvasTritonToEvas + TritonArithToLinalgNamed.cpp + TritonToEvasPipeline.cpp + + LINK_LIBS PUBLIC + MLIRBufferizationTransforms + MLIRPass + MLIRTransforms + StructuredToMemref + TritonArithToLinalg + TritonPtrToMemref + TritonToLinalgExperimental + TritonToStructured + TritonToUnstructured + UnstructuredToMemref + EvasLinalgIR + EVLinalgTransforms +) + +target_compile_options(EvasTritonToEvas PRIVATE -Wno-deprecated-declarations) diff --git a/third_party/evas/lib/Conversion/TritonToEvas/TritonArithToLinalgNamed.cpp b/third_party/evas/lib/Conversion/TritonToEvas/TritonArithToLinalgNamed.cpp new file mode 100644 index 0000000000..b6291ba9e3 --- /dev/null +++ b/third_party/evas/lib/Conversion/TritonToEvas/TritonArithToLinalgNamed.cpp @@ -0,0 +1,277 @@ +//===----------------------------------------------------------------------===// +// +// EVAS Triton arithmetic to linalg conversion. +// +// This mirrors triton-shared's triton-arith-to-linalg pass, but lowers the +// tensor elementwise arithmetic needed by EVAS to named linalg ops directly. +// +//===----------------------------------------------------------------------===// + +#include "evas/Conversion/TritonToEvas/TritonToEvasPipeline.h" +#include "evas/Dialect/Linalg/IR/LinalgOpsExt.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "triton-shared/Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; + +namespace { + +template +class TritonToLinalgNamedConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TritonOp op, typename TritonOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto dstType = dyn_cast(op.getType()); + if (!dstType) + return failure(); + + auto init = rewriter.create( + op.getLoc(), dstType.getShape(), dstType.getElementType()); + auto namedOp = rewriter.create( + op.getLoc(), op.getType(), adaptor.getOperands(), ValueRange(init), + linalg::getPrunedAttributeList(op)); + rewriter.replaceOp(op, namedOp.getResults()); + return success(); + } +}; + +template +class ArithCastToLinalgCastConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithCastOp op, typename ArithCastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto dstType = dyn_cast(op->getResult(0).getType()); + if (!dstType) + return failure(); + + auto init = rewriter.create( + op.getLoc(), dstType.getShape(), dstType.getElementType()); + auto castOp = rewriter.create( + op.getLoc(), TypeRange(dstType), adaptor.getOperands()[0], init); + rewriter.replaceOp(op, castOp.getResults()); + return success(); + } +}; + +void populateEvasElementwiseToLinalgPatterns(RewritePatternSet &patterns) { + patterns + .add, + TritonToLinalgNamedConverter, + TritonToLinalgNamedConverter, + TritonToLinalgNamedConverter, + ArithCastToLinalgCastConverter, + ArithCastToLinalgCastConverter>( + patterns.getContext(), PatternBenefit(2)); +} + +class EvasTritonArithToLinalgPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EvasTritonArithToLinalgPass) + + EvasTritonArithToLinalgPass(bool tensorPtrToLinalg, + bool transposeReduceToRank0) + : tensorPtrToLinalg(tensorPtrToLinalg), + transposeReduceToRank0(transposeReduceToRank0) {} + + StringRef getArgument() const final { return "evas-triton-arith-to-linalg"; } + StringRef getDescription() const final { + return "Convert Triton arithmetic operations to linalg with EVAS named " + "linalg elementwise lowering"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + + { + RewritePatternSet patterns(&getContext()); + mlir::triton::populateTritonArithToLinalgCanonicalizationPatterns( + patterns); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + return; + } + } + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, ttx::TritonTilingExtDialect, + tts::TritonStructuredDialect>(); + target.addLegalOp(); + target.addLegalOp(); + + target.addDynamicallyLegalDialect( + [](Operation *op) { + if (auto constOp = dyn_cast(op)) { + if (!isa(constOp.getResult().getType())) + return true; + if (auto denseAttr = + dyn_cast(constOp.getValue())) { + if (denseAttr.isSplat() && + isa(denseAttr.getElementType())) + return false; + } + return true; + } + + bool operateOnTensors = + llvm::all_of(op->getOperandTypes(), [](Type type) { + return isa(type); + }); + return !operateOnTensors; + }); + + target.addIllegalOp(); + target.addDynamicallyLegalOp( + [](mlir::triton::AddPtrOp op) { + return !isa(op.getResult().getType()); + }); + target.addDynamicallyLegalOp( + [this](mlir::triton::BitcastOp op) { + if (!tensorPtrToLinalg) + return mlir::triton::isPtrTypeLike(op.getType()); + if (mlir::triton::isPtrTypeLike(op.getType())) + return !isa(op.getType()); + return false; + }); + + if (tensorPtrToLinalg) { + target.addDynamicallyLegalOp([](auto op) { + return !isa(op->getOperands()[0].getType()); + }); + mlir::triton::populateTritonTensorPtrConversionPatterns(patterns); + } + + populateEvasElementwiseToLinalgPatterns(patterns); + mlir::triton::populateTritonArithToLinalgConversionPatterns( + /*pidsToFuncArgs=*/true, /*addptrToLinalg=*/true, + /*assertToCf=*/true, transposeReduceToRank0, patterns); + + addProgramInfo(); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + if (failed(applyTensorConcatDecomposition())) { + signalPassFailure(); + return; + } + convertTritonFuncToFunc(); + } + +private: + static auto constexpr LAUNCH_GRID_RANK = + mlir::triton::getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + void addProgramInfo() { + for (auto func : getOperation().getOps()) { + OpBuilder b(func); + auto origFuncType = func.getFunctionType(); + SmallVector newInputTypes(origFuncType.getInputs()); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); + func.setFunctionType( + b.getFunctionType(newInputTypes, origFuncType.getResults())); + + if (func.getAllArgAttrs()) { + SmallVector newArgAttrs; + func.getAllArgAttrs(newArgAttrs); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); + func.setAllArgAttrs(newArgAttrs); + } + + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) + func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); + } + } + + LogicalResult applyTensorConcatDecomposition() { + RewritePatternSet patterns(&getContext()); + tensor::populateDecomposeTensorConcatPatterns(patterns); + return applyPatternsGreedily(getOperation(), std::move(patterns)); + } + + void convertTritonFuncToFunc() { + getOperation().walk([&](mlir::triton::FuncOp func) { + OpBuilder builder(func); + auto funcFunc = func::FuncOp::create( + builder, func.getLoc(), func.getName(), func.getFunctionType()); + funcFunc.setVisibility(func.getVisibility()); + + SmallVector argAttrs, resAttrs; + func.getAllArgAttrs(argAttrs); + func.getAllResultAttrs(resAttrs); + funcFunc.setAllArgAttrs(argAttrs); + funcFunc.setAllResultAttrs(resAttrs); + + IRMapping map; + func.getBody().cloneInto(&funcFunc.getBody(), map); + + for (Block &block : funcFunc.getBody().getBlocks()) { + Operation *term = block.getTerminator(); + if (isa(term)) { + builder.setInsertionPoint(term); + func::ReturnOp::create(builder, func.getLoc(), term->getOperands()); + term->erase(); + } + } + func.erase(); + }); + } + + bool tensorPtrToLinalg; + bool transposeReduceToRank0; +}; + +} // namespace + +std::unique_ptr> +mlir::triton::evas::createEvasTritonArithToLinalgPass( + bool tensorPtrToLinalg, bool transposeReduceToRank0) { + return std::make_unique( + tensorPtrToLinalg, transposeReduceToRank0); +} diff --git a/third_party/evas/lib/Conversion/TritonToEvas/TritonToEvasPipeline.cpp b/third_party/evas/lib/Conversion/TritonToEvas/TritonToEvasPipeline.cpp new file mode 100644 index 0000000000..2bcdd975ca --- /dev/null +++ b/third_party/evas/lib/Conversion/TritonToEvas/TritonToEvasPipeline.cpp @@ -0,0 +1,66 @@ +#include "evas/Conversion/TritonToEvas/TritonToEvasPipeline.h" + +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h" +#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h" +#include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToPtr.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h" +#include "triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h" +#include "evas/Transform/Linalg/Passes.h" + +using namespace mlir; + +namespace mlir::triton::evas { + +void buildTritonToEvasPipeline(OpPassManager &pm) { + pm.addPass(triton::createTritonToStructuredPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(triton::createTritonToUnstructuredPass()); + pm.addPass(createEvasTritonArithToLinalgPass(/*tensorPtrToLinalg=*/true)); + pm.addPass(triton::createStructuredToMemrefPass()); + pm.addPass(triton::createUnstructuredToMemrefPass()); + pm.addPass(triton::createTritonPtrToMemrefPass()); + pm.addPass(triton::createTritonToPtrPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addPass(triton::createReconcilePtrCastsPass()); + pm.addPass(createRemoveDeadValuesPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + // EVAS does not support every scalar/index width that upstream Triton emits. + pm.addPass(ev::createRewriteDataTypePass()); + + bufferization::OneShotBufferizePassOptions bufferizeOptions; + bufferizeOptions.allowReturnAllocsFromLoops = true; + pm.addPass(bufferization::createOneShotBufferizePass(bufferizeOptions)); + + pm.addPass(ev::createRemoveLoopIterArgsWithMemrefTypePass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addPass(ev::createSetMemRefScopePass()); + pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); + pm.addPass(ev::createRemoveRedundencyCopyPass()); + + pm.addPass(ev::createInsertDeallocOpPass()); + pm.addPass(ev::createMemoryAllocPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); +} + +void registerTritonToEvasPipeline() { + PassPipelineRegistration<>( + "triton-to-evas", + "Lower Triton IR through triton-shared and EVAS backend passes", + [](OpPassManager &pm) { buildTritonToEvasPipeline(pm); }); +} + +} // namespace mlir::triton::evas diff --git a/third_party/evas/lib/Dialect/Linalg/IR/CMakeLists.txt b/third_party/evas/lib/Dialect/Linalg/IR/CMakeLists.txt new file mode 100644 index 0000000000..291c74d626 --- /dev/null +++ b/third_party/evas/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(EvasLinalgIR + LinalgOpsExt.cpp + + DEPENDS + EvasLinalgOpsExtIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLinalgDialect +) + +target_compile_options(EvasLinalgIR PRIVATE -Wno-deprecated-declarations) + diff --git a/third_party/evas/lib/Dialect/Linalg/IR/LinalgOpsExt.cpp b/third_party/evas/lib/Dialect/Linalg/IR/LinalgOpsExt.cpp new file mode 100644 index 0000000000..4cddab9245 --- /dev/null +++ b/third_party/evas/lib/Dialect/Linalg/IR/LinalgOpsExt.cpp @@ -0,0 +1,271 @@ +#include "evas/Dialect/Linalg/IR/LinalgOpsExt.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" + +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace mlir::linalg; + +namespace { + +static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value value, + int64_t dim) { + auto type = cast(value.getType()); + if (!type.isDynamicDim(dim)) + return builder.getIndexAttr(type.getDimSize(dim)); + + return getAsOpFoldResult( + TypeSwitch(value.getType()) + .Case([&](RankedTensorType) -> Value { + return builder.create(loc, value, dim); + }) + .Case([&](MemRefType) -> Value { + return builder.create(loc, value, dim); + })); +} + +static Operation *getSlice(OpBuilder &builder, Location loc, Value source, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + return TypeSwitch(source.getType()) + .Case([&](RankedTensorType) -> Operation * { + return builder.create(loc, source, offsets, + sizes, strides); + }) + .Case([&](MemRefType) -> Operation * { + return builder.create(loc, source, offsets, sizes, + strides); + }) + .Default([](Type) { return nullptr; }); +} + +static void getCastEffects( + CastOp op, + SmallVectorImpl> + &effects) { + if (isa(op.getInput().getType())) + effects.emplace_back(MemoryEffects::Read::get(), + &op->getOpOperand(CastOp::odsIndex_input), + /*stage=*/0, /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + + if (isa(op.getOutput().getType())) { + effects.emplace_back(MemoryEffects::Read::get(), + &op->getOpOperand(CastOp::odsIndex_output), + /*stage=*/0, /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), + &op->getOpOperand(CastOp::odsIndex_output), + /*stage=*/0, /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } +} + +template +static LogicalResult bufferizeDestinationStyleOp( + LinalgOp op, RewriterBase &rewriter, const BufferizationOptions &options, + BufferizationState &state) { + SmallVector operands; + operands.reserve(op->getNumOperands()); + + for (Value operand : op->getOperands()) { + if (!isa(operand.getType())) { + operands.push_back(operand); + continue; + } + + FailureOr buffer = getBuffer(rewriter, operand, options, state); + if (failed(buffer)) + return failure(); + operands.push_back(*buffer); + } + + auto newOp = + rewriter.create(op.getLoc(), TypeRange(), operands, + op->getAttrs()); + auto dstOp = cast(newOp.getOperation()); + replaceOpWithBufferizedValues(rewriter, op.getOperation(), + dstOp.getDpsInits()); + return success(); +} + +} // namespace + +void mlir::linalg::registerEvasLinalgOps(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *, linalg::LinalgDialect *dialect) { + RegisteredOperationName::insert(*dialect); + }); +} + +LogicalResult CastOp::verify() { + ShapedType inputType = getInputOperandType(); + ShapedType outputType = getOutputOperandType(); + if (inputType.getRank() != outputType.getRank()) + return emitOpError("incompatible shape rank"); + + for (auto [inputDim, outputDim] : + llvm::zip_equal(inputType.getShape(), outputType.getShape())) { + if (inputDim != outputDim) + return emitOpError("wrong shape"); + } + return success(); +} + +SmallVector CastOp::getIterationDomain(OpBuilder &builder) { + int64_t rank = getInputOperandRank(); + SmallVector loopBounds(rank); + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value source = getInput(); + for (int64_t dim = 0; dim < rank; ++dim) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, source, dim); + loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector CastOp::getLoopIteratorTypes() { + return SmallVector(getInputOperandRank(), + utils::IteratorType::parallel); +} + +FailureOr +CastOp::getTiledImplementation(OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes) { + int64_t rank = getInputOperandRank(); + SmallVector strides(rank, builder.getI64IntegerAttr(1)); + + Operation *inputSlice = + getSlice(builder, getLoc(), getInput(), offsets, sizes, strides); + if (!inputSlice) + return emitOpError("failed to compute input slice"); + + Operation *outputSlice = + getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides); + if (!outputSlice) + return emitOpError("failed to compute output slice"); + + SmallVector tiledOperands{inputSlice->getResult(0), + outputSlice->getResult(0)}; + SmallVector resultTypes; + if (hasPureTensorSemantics()) + resultTypes.push_back(tiledOperands[1].getType()); + + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{ + {tiledOp}, + SmallVector(tiledOp->getResults()), + llvm::to_vector(ArrayRef{inputSlice, outputSlice})}; +} + +LogicalResult CastOp::getResultTilePosition( + OpBuilder &, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + if (resultNumber != 0) + return failure(); + resultOffsets.assign(offsets.begin(), offsets.end()); + resultSizes.assign(sizes.begin(), sizes.end()); + return success(); +} + +FailureOr +CastOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) { + SmallVector resultOffsets; + SmallVector resultSizes; + if (failed(getResultTilePosition(builder, resultNumber, offsets, sizes, + resultOffsets, resultSizes))) + return failure(); + return getTiledImplementation(builder, resultOffsets, resultSizes); +} + +LogicalResult CastOp::fold(FoldAdaptor, + SmallVectorImpl &results) { + if (hasPureTensorSemantics()) { + if (getInputOperandType() != getOutputOperandType()) + return failure(); + + auto isInvalid = [](Value value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) + return true; + return isa(defOp); + }; + if (isInvalid(getInput()) || isInvalid(getOutput())) + return failure(); + + results.push_back(getInput()); + return success(); + } + + bool folded = false; + for (OpOperand &operand : getOperation()->getOpOperands()) { + auto cast = operand.get().getDefiningOp(); + if (cast && !isa(cast.getOperand().getType()) && + !cast->hasAttr("no_fold")) { + operand.set(cast.getOperand()); + folded = true; + } + } + return success(folded); +} + +void CastOp::getEffects( + SmallVectorImpl> + &effects) { + getCastEffects(*this, effects); +} + +bool CastOp::bufferizesToMemoryRead(OpOperand &, + const AnalysisState &) { + return true; +} + +bool CastOp::bufferizesToMemoryWrite(OpOperand &opOperand, + const AnalysisState &) { + return opOperand.getOperandNumber() == CastOp::odsIndex_output; +} + +AliasingValueList CastOp::getAliasingValues(OpOperand &opOperand, + const AnalysisState &) { + if (opOperand.getOperandNumber() == CastOp::odsIndex_output) + return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; + return {}; +} + +LogicalResult CastOp::bufferize(RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) { + return bufferizeDestinationStyleOp(*this, rewriter, options, state); +} + +#include "evas/Dialect/Linalg/IR/LinalgOpsExtEnums.cpp.inc" + +#define GET_OP_CLASSES +#include "evas/Dialect/Linalg/IR/LinalgOpsExt.cpp.inc" diff --git a/third_party/evas/lib/Transform/Linalg/Bufferize.cpp b/third_party/evas/lib/Transform/Linalg/Bufferize.cpp new file mode 100644 index 0000000000..c595e13bb6 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/Bufferize.cpp @@ -0,0 +1,60 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include + +#define GEN_PASS_DEF_BUFFERIZE +#include "evas/Transform/Linalg/Passes.h.inc" +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; + +struct BufferizePass : public ::impl::BufferizeBase { + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder builder(moduleOp.getContext()); + + bufferization::OneShotBufferizationOptions bufferizeOption; + bufferizeOption.bufferizeFunctionBoundaries = false; + bufferizeOption.allowReturnAllocsFromLoops = true; + bufferization::BufferizationState state; + // default memory scope on ddr + bufferizeOption.defaultMemorySpaceFn = + [](TensorType t) -> std::optional { + return IntegerAttr::get(IntegerType::get(t.getContext(), 64), + mlir::ev::MemScope::DDR); + }; + bufferizeOption.opFilter.allowOperation([](Operation *op) { + // If it's a function, only allow "kernel" + if (auto funcOp = dyn_cast(op)) + return funcOp.getSymName() == "kernel"; + // For other ops, check if they're inside "kernel" + auto parentFunc = op->getParentOfType(); + return parentFunc && parentFunc.getSymName() == "kernel"; + }); + // bufferizeOption.opFilter.denyDialect(); + // bufferizeOption.opFilter.allowDialect(); + // bufferizeOption.opFilter.denyOperation(); + if (failed(bufferization::runOneShotBufferize(moduleOp, bufferizeOption, + state))) { + signalPassFailure(); + } + } +}; + +} // namespace +std::unique_ptr createBufferizePass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/CMakeLists.txt b/third_party/evas/lib/Transform/Linalg/CMakeLists.txt new file mode 100644 index 0000000000..cb0e66b917 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/CMakeLists.txt @@ -0,0 +1,21 @@ +add_triton_library(EVLinalgTransforms + InsertDeallocOp.cpp + RegionMemAllocator.cpp + MemoryAlloc.cpp + MemoryAllocPass.cpp + RemoveLoopIterArgsWithMemrefType.cpp + SetMemRefScopePass.cpp + RemoveRedundencyCopyPass.cpp + RewriteDataType.cpp + DEPENDS + EVLinalgTransformsIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + MLIRBufferizationTransforms + MLIRLinalgTransforms + EvasLinalgIR +) + +target_compile_options(EVLinalgTransforms PRIVATE -Wno-deprecated-declarations) diff --git a/third_party/evas/lib/Transform/Linalg/DoubleBuffer.cpp b/third_party/evas/lib/Transform/Linalg/DoubleBuffer.cpp new file mode 100644 index 0000000000..1cffbaa82c --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/DoubleBuffer.cpp @@ -0,0 +1,233 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton-shared/Transform/common_utils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include +#include +#define GEN_PASS_DEF_DOUBLEBUFFER +#include "evas/Transform/Linalg/Passes.h.inc" + +using namespace mlir; +namespace mlir::triton::ev { + +namespace { + +class DoubleBufferPass : public ::impl::DoubleBufferBase { +public: + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + // Walk through all scf.for operations in the function + moduleOp.walk([&](scf::ForOp forOp) { + // Skip if this loop is not suitable for optimization + if (auto target = FindPrefetchTarget(forOp)) + applyPrefetching(forOp, target); + }); + } + +private: + // Check if the loop is suitable for double buffering optimization + Operation *FindPrefetchTarget(scf::ForOp forOp) { + // Check if the loop has constant bounds and step + auto lowerBound = forOp.getLowerBound(); + auto upperBound = forOp.getUpperBound(); + auto step = forOp.getStep(); + + // Check if lower bound is constant 0 + auto lowerBoundOp = lowerBound.getDefiningOp(); + if (!lowerBoundOp) + return nullptr; + + for (auto &op : llvm::reverse(forOp.getBody()->without_terminator())) { + if (auto callOp = dyn_cast(op)) { + if (callOp->hasAttr(mlir::ev::prefetchName) && + callOp->getAttrOfType(mlir::ev::prefetchName) + .getValue()) { + return &op; + } + } + } + + return nullptr; + } + // Clone operations for prefetching before the loop + void clonePrefetchOperations(mlir::scf::ForOp forOp, Operation *target, + mlir::OpBuilder &builder, + IRMapping &prefetchMap) { + + // Get the loop induction variable and its initial value + Value inductionVar = forOp.getInductionVar(); + Value lowerBound = forOp.getLowerBound(); + builder.setInsertionPoint(forOp); + // Map the induction variable to the lower bound for the prefetch operations + prefetchMap.map(inductionVar, lowerBound); + + // Clone operations from the loop body for the first iteration + Block &loopBody = forOp.getRegion().front(); + + for (Operation &op : + llvm::make_range(loopBody.begin(), std::next(target->getIterator()))) { + // Skip the terminator + if (op.hasTrait()) + continue; + // Clone the operation with mapped operands + builder.clone(op, prefetchMap); + } + } + Value findToTensorUser(Operation *op) { + // Iterate through all users of the operation's result + for (Operation *user : op->getResult(0).getUsers()) { + // Check if the user is a ToTensorOp + if (auto toTensorOp = dyn_cast(user)) { + return toTensorOp->getResult(0); + } + } + return nullptr; + } + mlir::scf::ForOp createNewForOp(mlir::scf::ForOp forOp, Operation *target, + mlir::OpBuilder &builder, + mlir::IRMapping &prefetchMap, + mlir::IRMapping &newForOpMap) { + auto loc = forOp.getLoc(); + auto lowerBound = forOp.getLowerBound(); + auto upperBound = forOp.getUpperBound(); + auto step = forOp.getStep(); + builder.setInsertionPoint(forOp); + // TODO(wyann): suporrt multiple outputs + // auto targetValue = target->getResult(0); + // loopArgs.push_back(prefetchMap.lookup(targetValue)); + auto newForOp = builder.create(loc, lowerBound, upperBound, + step, forOp.getInitArgs()); + Block *newloopBody = newForOp.getBody(); + builder.setInsertionPointToStart(newloopBody); + auto prefetchInductor = builder.create( + loc, newForOp.getInductionVar().getType(), newForOp.getInductionVar(), + newForOp.getStep()); + newForOpMap.map(forOp.getInductionVar(), prefetchInductor); + auto condPreftchOutOfBound = builder.create( + loc, mlir::arith::CmpIPredicate::ne, prefetchInductor, + newForOp.getUpperBound()); + auto ifOp = + builder.create(loc, condPreftchOutOfBound, false); + auto ifOpThenBlock = ifOp.thenBlock(); + builder.setInsertionPointToStart(ifOpThenBlock); + auto prefetchInductorMod = builder.create( + loc, prefetchInductor.getType(), prefetchInductor, + builder.create( + loc, prefetchInductor.getType(), 2)); + + auto doubleBufferSelectCond = builder.create( + loc, mlir::arith::CmpIPredicate::eq, prefetchInductorMod, + builder.create( + loc, prefetchInductorMod.getType(), 0)); + // clone before target + llvm::SmallDenseMap pangAllocsMap; + for (Operation &op : llvm::make_range(forOp.getBody()->begin(), + std::next(target->getIterator()))) { + builder.clone(op, newForOpMap); + if (isa(op)) { + auto pang_alloc = newForOpMap.lookup(op.getResult(0)); + auto ping_alloc = prefetchMap.lookup(op.getResult(0)); + auto dbSelect = builder.create( + loc, doubleBufferSelectCond, ping_alloc, pang_alloc); + // hoist the double buffer alloc out of loop + newForOpMap.map(op.getResult(0), dbSelect); + pangAllocsMap[&op] = pang_alloc; + pang_alloc.getDefiningOp()->moveBefore(newForOp); + } + } + builder.setInsertionPointAfter(ifOp); + + // adjust coveredOp inductor value + newForOpMap.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + auto coveredInductorMod = builder.create( + loc, newForOp.getInductionVar().getType(), newForOp.getInductionVar(), + builder.create( + loc, newForOp.getInductionVar().getType(), 2)); + + auto selectCond = builder.create( + loc, mlir::arith::CmpIPredicate::eq, coveredInductorMod, + builder.create( + loc, coveredInductorMod.getType(), 0)); + // adjust buffer selected op + + for (auto [op, pangAlloc] : pangAllocsMap) { + auto toTensorUser = findToTensorUser(op); + auto selectOp = builder.create( + loc, selectCond, prefetchMap.lookup(op->getResult(0)), + pangAlloc); + newForOpMap.map( + toTensorUser == nullptr ? op->getResult(0) : toTensorUser, + builder.create( + loc, memref::getTensorTypeFromMemRefType(selectOp.getType()), selectOp)); + } + + // clone after target + for (Operation &op : + llvm::make_range(std::next(target->getIterator()), + forOp.getBody()->without_terminator().end())) { + + // Clone the operation with mapped operands + builder.clone(op, newForOpMap); + } + return newForOp; + } + + void addFinalIterationCode(mlir::scf::ForOp forOp, mlir::scf::ForOp newForOp, + Operation *target, mlir::OpBuilder &builder, + mlir::IRMapping &newForOpMap) { + builder.setInsertionPointAfter(newForOp); + IRMapping finalIterMap; + + finalIterMap.map(forOp.getInductionVar(), newForOp.getUpperBound()); + // TODO(wyann): suporrt multiple outputs + finalIterMap.map(target->getResult(0), newForOp.getResult(0)); + Block *loopBody = forOp.getBody(); + for (Operation &op : + llvm::make_range(std::next(target->getIterator()), + loopBody->without_terminator().end())) { + builder.clone(op, finalIterMap); + } + } + // Apply the double buffering optimization to the loop + void applyPrefetching(scf::ForOp forOp, Operation *target) { + OpBuilder builder(forOp); + MLIRContext *context = builder.getContext(); + IRMapping prefetchMap; + IRMapping newForOpMap; + // 1. Clone the first iteration before the loop + clonePrefetchOperations(forOp, target, builder, prefetchMap); + // 2. Create a new loop with adjusted bounds + auto newForOp = + createNewForOp(forOp, target, builder, prefetchMap, newForOpMap); + // // 3. Add code for the final iteration after the loop + // addFinalIterationCode(forOp, newForOp, target, builder, newForOpMap); + // 4. Remove the original loop + forOp.erase(); + } +}; + +} // namespace + +std::unique_ptr createDoubleBufferPass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/EncapsulateLinalgOp.cpp b/third_party/evas/lib/Transform/Linalg/EncapsulateLinalgOp.cpp new file mode 100644 index 0000000000..9d06a2a4a0 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/EncapsulateLinalgOp.cpp @@ -0,0 +1,583 @@ +#include + +#include "epu/memory.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +#include "triton-shared/Transform/common_utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#define GEN_PASS_DEF_ENCAPSULATELINALGOP +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { +namespace { +using mlir::func::FuncOp; +using mlir::linalg::LinalgOp; + +static constexpr llvm::StringRef kSchedulePrimitive = "schedule_primitive"; + +using Cluster = llvm::SmallVector; +raw_ostream &operator<<(raw_ostream &os, const Cluster &cluster) { + os << "[\n"; + for (size_t i = 0; i < cluster.size(); ++i) { + if (i != 0) { + os << ", "; // Separate elements with a comma + } + if (cluster[i] != nullptr) { + // Assuming you want to output the address of the Operation object + os << "Operation " << i << ": " << *(cluster[i]) + << "\n"; // Output the address (or use any relevant member function) + // Alternatively, you can call a member function to display additional + // information if desired cluster[i]->print(); + } else { + os << "nullptr\n"; // Handle nullptr entries + } + } + os << "]\n"; + return os; +} + +void reorderCluster(Cluster &cluster) { + if (cluster.empty()) + return; + std::sort(cluster.begin(), cluster.end(), + [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); }); +} + +bool isComputationalOp(Operation *op) { + StringRef opName = op->getName().getStringRef(); + return opName.starts_with("linalg") && opName != "linalg.yield" && opName != "linalg.generic"; +} + +bool isOnlyUser(Operation *A, Operation *B) { + // Check if A has any results + if (A->getNumResults() == 0) + return false; + // Check all results of A + for (Value result : A->getResults()) { + // If any result has no uses, return false + if (result.use_empty()) + return false; + // Check if all uses of this result are in operation B + for (OpOperand &use : result.getUses()) { + if (use.getOwner() != B) + return false; + } + } + // All results of A are only used by B + return true; +} + +bool usedOnlyInCluster(Operation *op, const Cluster &cluster) { + for (Value result : op->getResults()) { + for (OpOperand &use : result.getUses()) { + if (!llvm::is_contained(cluster, use.getOwner())) + return false; + } + } + return true; +} + +SmallVector getInputsOfCluster(const Cluster &cluster) { + llvm::SmallVector inputs; + llvm::SmallDenseSet inputSet; + llvm::SmallDenseSet opSet; + for (Operation *op : cluster) { + bool inserted = opSet.insert(op).second; + (void)inserted; + assert(inserted && "cluster contains duplicate operations"); + } + + for (Operation *op : cluster) { + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (opSet.find(defOp) != opSet.end()) { + // skip if defining op is in the cluster + continue; + } + if (inputSet.insert(operand).second) { + inputs.push_back(operand); + } + } + } + return inputs; +} + +SmallVector getOutputsOfCluster(const Cluster &cluster) { + llvm::SmallVector outputs; + llvm::SmallDenseSet opSet; + for (Operation *op : cluster) { + // Should add all the operations recursively because a value might be used + // by an operation of an inner region. + op->walk([&](Operation *innerOp) { + bool inserted = opSet.insert(innerOp).second; + (void)inserted; + assert(inserted && "cluster contains duplicate operations"); + }); + } + + for (Operation *op : cluster) { + for (Value result : op->getResults()) { + bool hasExternalUser = + llvm::any_of(result.getUses(), [&](OpOperand &use) { + return !opSet.count(use.getOwner()); + }); + if (hasExternalUser) { + outputs.push_back(result); + } + } + } + return outputs; +} + +Operation *getFirstOpInCluster(const Cluster &cluster) { + Operation *firstOp = *std::min_element( + cluster.begin(), cluster.end(), + [](Operation *x, Operation *y) { return x->isBeforeInBlock(y); }); + return firstOp; +} + +Operation *getLastOpInCluster(const Cluster &cluster) { + Operation *lastOp = *std::max_element( + cluster.begin(), cluster.end(), + [](Operation *x, Operation *y) { return x->isBeforeInBlock(y); }); + return lastOp; +} + +void moveConsumer(const Cluster &cluster) { + Operation *firstOp = getFirstOpInCluster(cluster); + Operation *lastOp = getLastOpInCluster(cluster); + + llvm::SmallDenseSet fusedSet(cluster.begin(), cluster.end()); + llvm::SmallDenseSet consumerSet; + + llvm::SmallVector consumersVec; + auto firstIter = firstOp->getIterator(); + auto lastIter = lastOp->getIterator(); + + for (Operation &curOp : llvm::make_range(firstIter, lastIter)) { + // isn't fused op && consumer's op + // move this after fusion op + if (!fusedSet.contains(&curOp)) { + // fused op's consumer or consumer's consumer + bool isConsumer = + llvm::any_of(curOp.getOperands(), [&fusedSet, &consumerSet](Value v) { + auto op = v.getDefiningOp(); + return fusedSet.contains(op) || consumerSet.contains(op); + }); + if (isConsumer) { + consumerSet.insert(&curOp); + consumersVec.push_back(&curOp); + } + } + } + + for (auto op : llvm::reverse(consumersVec)) { + op->moveAfter(lastOp); + } +} + +bool isCallOpRet(Value v) { + auto op = v.getDefiningOp(); + if (!op) + return false; + if (isa(op)) { + return true; + } + for (auto operand : op->getOperands()) { + if (isCallOpRet(operand)) + return true; + } + return false; +} + +bool isScalarType(Type type) { + return isa(type) || isa(type) || + isa(type) || isa(type); +} + +Operation *findDstInput(Operation *op) { + if (!op) + return nullptr; + if (isa_and_nonnull(op)) { + return op; + } + if (isa_and_nonnull(op)) { + return op; + } + if (op->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + return findDstInput(op->getOperand(0).getDefiningOp()); + } else if (auto fillOp = llvm::dyn_cast_or_null(op)) { + return findDstInput(fillOp.getOperand(1).getDefiningOp()); + } + return nullptr; +} + +Cluster findDestinationOps(Operation *op, + const llvm::SmallDenseSet &dstInputSet, + const Cluster &cls) { + Cluster destinationOps; + // Process each operand of the operation + for (Value input : op->getOperands()) { + Operation *inputOp = input.getDefiningOp(); + // Skip if this is an output dstInput + if (!inputOp || dstInputSet.contains(inputOp)) + continue; + // Skip if the input op is not used only in the cluster + if (!usedOnlyInCluster(inputOp, cls)) + continue; + // If this is an allocation operation, add it directly + if (isa(inputOp)) { + destinationOps.push_back(inputOp); + continue; + } + // Recursively process tensor dialect operations + if (inputOp->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + // Get destination ops from the tensor op + Cluster attachedOps = findDestinationOps(inputOp, dstInputSet, cls); + // If we found destination ops, add them and the tensor op + if (!attachedOps.empty()) { + destinationOps.append(attachedOps); + destinationOps.push_back(inputOp); + } + } + } + + return destinationOps; +} + +std::string getFuncName(int clusterIdx) { + std::ostringstream nameStream; + nameStream << "sub_kernel_" << clusterIdx; + return nameStream.str(); +} + +class EncapsulateLinalgOpPass + : public ::impl::EncapsulateLinalgOpBase { +public: + int getClusterIndex(Operation *op) { + for (auto indexedCluster : llvm::enumerate(clusters)) { + auto cluster = indexedCluster.value(); + for (auto clusterOp : cluster) { + if (clusterOp == op) + return indexedCluster.index(); + } + } + return -1; + } + + int mergeClusters(int c1, int c2) { + if (c1 == c2) + return c1; + if (c1 > c2) { + return mergeClusters(c2, c1); + } + Cluster cluster2 = clusters[c2]; + clusters.erase(clusters.begin() + c2); + clusters[c1].append(cluster2.begin(), cluster2.end()); + return c1; + } + + void InitClusters(FuncOp funcOp) { + funcOp.walk([this](Operation *op) { + if (isComputationalOp(op)) { + Cluster newCls; + newCls.push_back(op); + clusters.push_back(newCls); + } + return WalkResult::advance(); + }); + } + + bool ConnnectedTo(const Cluster &clsA, const Cluster &clsB) { + auto valuesA = getOutputsOfCluster(clsA); + auto valuesB = getInputsOfCluster(clsB); + for (auto in : valuesA) { + for (auto out : valuesB) { + if (in == out) + return true; + } + } + return false; + } + + bool HasSingleOutput(const Cluster &clsA) { + auto outputs = getOutputsOfCluster(clsA); + return outputs.size() <= 1; + } + + Cluster getMergedCls(const Cluster &clsA, const Cluster &clsB) { + Cluster ret = clsA; + ret.append(clsB.begin(), clsB.end()); + return ret; + } + + template bool hasLinalgOp(const Cluster &cls) { + for (auto op : cls) { + if (isa(op)) + return true; + } + return false; + } + + template bool IsolatedPattern(const Cluster &cls) { + if (!hasLinalgOp(cls)) + return true; + for (auto op : cls) { + if (!(isa(op) || isa(op))) + return false; + } + return true; + } + + template bool TryIsolatedPattern(const Cluster &cls) { + return (... && IsolatedPattern(cls)); + } + + bool TryRestrictedFusePattern(const Cluster &cls) { + // Multi-output is not supported by evofc for now + if (!HasSingleOutput(cls)) + return false; + if (!TryIsolatedPattern( + cls)) + return false; + // transpose + return true; + } + + bool CanFuseTo(const Cluster &clsA, const Cluster &clsB) { + if (!ConnnectedTo(clsA, clsB)) + return false; + + Operation *lastOpA = getLastOpInCluster(clsA); + Operation *firstOpB = getFirstOpInCluster(clsB); + Cluster clsMid; + // Check if lastOpA and firstOpB are in the same block + if (lastOpA->getBlock() != firstOpB->getBlock()) + return false; + for (auto &op : llvm::make_range(std::next(lastOpA->getIterator()), + firstOpB->getIterator())) { + clsMid.push_back(&op); + } + // annotate op to cut the fusion + if (utils::getAnnotation(lastOpA)) + return false; + // check if the middle cluster is connected to both clsA and clsB + if (!clsMid.empty() && ConnnectedTo(clsA, clsMid) && + ConnnectedTo(clsMid, clsB)) { + return false; + } + auto tryMerged = getMergedCls(clsA, clsB); + return TryRestrictedFusePattern(tryMerged); + } + + void FuseClustersWithDefUse() { + if (clusters.size() <= 1) + return; + for (size_t index = 0; index < clusters.size() - 1; ++index) { + if (CanFuseTo(clusters[index], clusters[index + 1])) { + (void)mergeClusters(index, index + 1); + FuseClustersWithDefUse(); + return; + } + } + } + + Cluster getAttachedCluster(const Cluster &cls) { + Cluster ret = cls; + auto outputs = getOutputsOfCluster(cls); + auto outputsSet = + llvm::SmallDenseSet(outputs.begin(), outputs.end()); + llvm::SmallDenseSet dstInputSet; + // find all the dst input ops that correspond to the cluster outputs + for (auto op : cls) { + auto dstOp = cast(op); + for (auto [idx, output] : llvm::enumerate(op->getResults())) { + if (outputsSet.contains(output)) { + auto init = + findDstInput(dstOp.getDpsInitOperand(idx)->get().getDefiningOp()); + dstInputSet.insert(init); + outputToDstInput[output] = init; + } + } + } + for (auto op : cls) { + auto dst_ops = findDestinationOps(op, dstInputSet, cls); + ret.append(dst_ops); + } + reorderCluster(ret); + return ret; + } + + void setMemscopeForAllocTensorOp(bufferization::AllocTensorOp allocOp, + OpBuilder &b) { + allocOp.setMemorySpaceAttr( + b.getI64IntegerAttr((int64_t)mlir::triton::MemScope::MM)); + if (isScalarType(allocOp.getType())) { + allocOp.setMemorySpaceAttr( + b.getI64IntegerAttr((int64_t)mlir::triton::MemScope::DDR)); + } + } + + bool isAnnotatedPrefetch(Operation *op) { + if (auto annotateOp = utils::getAnnotation(op)) { + auto meminfo = annotateOp.getMeminfo(); + return meminfo.getPrefetch(); + } + return false; + } + + void annotatePrefetchToSubKernel(bufferization::AllocTensorOp allocOp, + func::CallOp callOp, OpBuilder &b) { + if (isAnnotatedPrefetch(allocOp)) { + callOp->setAttr(mlir::ev::prefetchName, b.getBoolAttr(true)); + } + } + + func::FuncOp createFuncOpWithCluster(OpBuilder &b, StringRef subFnName, + ValueRange inputs, ValueRange outputs, + const Cluster &cluster, + Operation *insertionPoint) { + Operation *lastOp = getLastOpInCluster(cluster); + llvm::SmallVector locations; + locations.reserve(cluster.size()); + for (Operation *op : cluster) { + locations.push_back(op->getLoc()); + } + Location fusedLoc = FusedLoc::get(lastOp->getContext(), locations); + + llvm::SmallVector outputTypes; + outputTypes.reserve(outputs.size()); + for (Value v : outputs) { + outputTypes.push_back(v.getType()); + } + llvm::SmallVector inputTypes; + inputTypes.reserve(inputs.size()); + for (Value v : inputs) { + inputTypes.push_back(v.getType()); + } + + moveConsumer(cluster); + + auto subFnType = b.getFunctionType(inputTypes, outputTypes); + b.setInsertionPoint(insertionPoint); + func::FuncOp subFnOp = + b.create(fusedLoc, subFnName, subFnType); + subFnOp.setSymVisibility("private"); + b.setInsertionPoint(lastOp); + auto callOp = b.create(fusedLoc, subFnOp, inputs); + callOp->setAttr(kSchedulePrimitive, b.getBoolAttr(true)); + // callOp->setAttr(mlir::ev::addrName, + // b.getArrayAttr(SmallVector( + // callOp.getNumResults(), b.getI64IntegerAttr(-1)))); + Block *block = subFnOp.addEntryBlock(); + b.setInsertionPoint(block, block->end()); + IRMapping bvm; + for (auto inputAndArg : llvm::zip(inputs, subFnOp.getArguments())) { + bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg)); + } + for (Operation *op : cluster) { + b.clone(*op, bvm); + } + llvm::SmallVector funcReturns; + for (Value output : outputs) { + funcReturns.push_back(bvm.lookupOrDefault(output)); + } + b.create(fusedLoc, funcReturns); + + for (auto outputAndResult : llvm::zip(outputs, callOp.getResults())) { + Value output = std::get<0>(outputAndResult); + // replace the use of output with the destination alloc op + Operation *dstInputOp = outputToDstInput[output]; + Value dstInputValue = dstInputOp->getResult(0); + Value callResult = std::get<1>(outputAndResult); + for (OpOperand &use : llvm::make_early_inc_range(output.getUses())) { + use.set(dstInputValue); + } + // 只对 bufferization::AllocTensorOp 设置内存空间和预取属性 + if (auto tensorAllocOp = + dyn_cast(dstInputOp)) { + setMemscopeForAllocTensorOp(tensorAllocOp, b); + // todo:需要考虑totensor的情况 + // annotatePrefetchToSubKernel(tensorAllocOp, callOp, b); + } + } + + // erase dead ops in the end + for (Operation *op : llvm::reverse(cluster)) { + if (op->use_empty()) { + op->erase(); + } + } + + return subFnOp; + } + FailureOr createFuncOpWithCluster(OpBuilder &b, + StringRef subFnName, + const Cluster &cluster, + Operation *insertionPoint) { + auto attachedCluster = getAttachedCluster(cluster); + llvm::SmallVector inputs = getInputsOfCluster(attachedCluster); + llvm::SmallVector outputs = getOutputsOfCluster(attachedCluster); + return createFuncOpWithCluster(b, subFnName, inputs, outputs, + attachedCluster, insertionPoint); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + // set funcName fixed for finding the outer func in further process + auto f = *(m.getOps().begin()); + const std::string funcName = "kernel"; + f.setName(funcName); + InitClusters(f); + OpBuilder b(f); + + SymbolTable symTable(m); + for (auto c : llvm::enumerate(clusters)) { + FailureOr subFnOp = createFuncOpWithCluster( + b, getFuncName(c.index()), c.value(), f.getOperation()); + assert(mlir::succeeded(subFnOp) && "create FuncOp failed"); + symTable.insert(*subFnOp); + } + } + +private: + SmallVector clusters; + llvm::SmallDenseMap outputToDstInput; +}; + +} // namespace + +std::unique_ptr createEncapsulateLinalgOpPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/InsertDeallocOp.cpp b/third_party/evas/lib/Transform/Linalg/InsertDeallocOp.cpp new file mode 100644 index 0000000000..ca434eff03 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/InsertDeallocOp.cpp @@ -0,0 +1,93 @@ +#include "epu/memory.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#define GEN_PASS_DEF_INSERTDEALLOCOP +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; +struct InsertDeallocOpPass + : public ::impl::InsertDeallocOpBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + for (auto func : module.getOps()) { + if (func.isDeclaration() || func.getSymName() != "kernel") + continue; + Liveness liveness(func); + SmallVector allocValues; + func.walk([&](Operation *op) { + if (mlir::ev::isMemoryAllocOp(op)) { + allocValues.push_back(op->getResult(0)); + } else if (mlir::ev::isSubkernelBufferOp(op)) { + for (auto result : op->getResults()) { + allocValues.push_back(result); + } + } + }); + OpBuilder builder(func.getBody()); + for (auto allocValue : allocValues) { + if (Operation *lastUser = + findLastUserInOneBlock(allocValue, liveness)) { + builder.setInsertionPointAfter(lastUser); + if (auto memrefAlloc = + dyn_cast(allocValue.getOwner())) { + builder.create(lastUser->getLoc(), memrefAlloc); + } else { + builder.create(lastUser->getLoc(), + allocValue); + } + } + } + } + } + +private: + bool obtain(Value src, Operation *dst, Block *block) { + for (Operation *useOp : src.getUsers()) { + useOp = block->findAncestorOpInBlock(*useOp); + if (!useOp) + continue; + if (useOp == dst) { + return true; + } + if ((isa(useOp) || + useOp->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) && + obtain(useOp->getResult(0), dst, block)) { + return true; + } + } + return false; + } + + // liveness analysis within multiple blocks is not supported for now + Operation *findLastUserInOneBlock(OpResult allocValue, Liveness liveness) { + Block *block = allocValue.getOwner()->getBlock(); + auto liveInfo = liveness.getLiveness(block); + assert(!liveInfo->isLiveOut(allocValue) && + "liveness within multiple blocks is not supported"); + for (Operation &curOp : llvm::reverse(llvm::make_range( + allocValue.getOwner()->getIterator(), block->without_terminator().end()))) { + if (obtain(allocValue, &curOp, block)) { + return &curOp; + } + } + return nullptr; + } +}; +} // namespace +std::unique_ptr createInsertDeallocOpPass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev \ No newline at end of file diff --git a/third_party/evas/lib/Transform/Linalg/MaterializeAnnotation.cpp b/third_party/evas/lib/Transform/Linalg/MaterializeAnnotation.cpp new file mode 100644 index 0000000000..537fd884b6 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/MaterializeAnnotation.cpp @@ -0,0 +1,119 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +#define GEN_PASS_DEF_MATERIALIZEANNOTATION +#include "evas/Transform/Linalg/Passes.h.inc" +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; + +// Modify memory scope for connected memref values recursively +void InferMemrefType(Value value, bool onlyScope = true) { + // Get the type of input value + Type inputType = value.getType(); + + // Get all users of this value + for (Operation *user : value.getUsers()) { + // Only infer op with ViewLikeOpInterface + if (!dyn_cast(user)) + continue; + + // For each result of the user operation + for (Value result : user->getResults()) { + // If result type can be cast to MemRefType + if (MemRefType resultMemRef = dyn_cast(result.getType())) { + // Create new memref type with same properties but input memory scope + + if (MemRefType inputMemRef = dyn_cast(inputType)) { + if (onlyScope) { + auto newType = MemRefType::get( + resultMemRef.getShape(), resultMemRef.getElementType(), + resultMemRef.getLayout(), + IntegerAttr::get(IntegerType::get(value.getContext(), 64), + (int64_t)inputMemRef.getMemorySpaceAsInt())); + // Update the result type + result.setType(newType); + } else { + result.setType(inputMemRef); + } + // Recursively update memory scope for connected values + InferMemrefType(result, onlyScope); + } + } + } + } +} + +void foldMemrefCopy(ModuleOp moduleOp) { + moduleOp.walk([&](memref::CopyOp copyOp) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + auto srcType = dyn_cast(src.getType()); + auto dstType = dyn_cast(dst.getType()); + + if (!srcType || !dstType || dstType != srcType) + return; + // Replace all uses of dst with src and erase the copy op + dst.replaceAllUsesWith(src); + copyOp.erase(); + }); +} + +struct MaterializeAnnotationPass + : public ::impl::MaterializeAnnotationBase { + + void materializeAddress(tts::AnnotateOp annotateOp, OpBuilder &builder) { + Value input = annotateOp.getSrc(); + auto allocOp = input.getDefiningOp(); + if (!allocOp) + return; + + auto memInfoAttr = annotateOp.getMeminfoAttr(); + if (!memInfoAttr) + return; + + auto memrefType = cast(allocOp.getType()); + int memScope = (int)memInfoAttr.getScope(); + auto address = memInfoAttr.getAddress(); + + // memscope should be handled already + assert(memScope == memrefType.getMemorySpaceAsInt()); + + if (address > 0) { + allocOp->setDiscardableAttr(mlir::ev::phyAddrName, + builder.getI64IntegerAttr(address)); + } + + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder builder(moduleOp.getContext()); + + // Walk through all AnnotateOp operations in the module + moduleOp.walk([&](tts::AnnotateOp annotateOp) { + materializeAddress(annotateOp, builder); + }); + + // Remove all annotate ops + moduleOp.walk([&](tts::AnnotateOp annotateOp) { annotateOp.erase(); }); + + } +}; + +} // namespace +std::unique_ptr createMaterializeAnnotationPass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/MemoryAlloc.cpp b/third_party/evas/lib/Transform/Linalg/MemoryAlloc.cpp new file mode 100644 index 0000000000..a589bf72c7 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/MemoryAlloc.cpp @@ -0,0 +1,139 @@ +//===----------------------- MemoryAlloc.cpp --------------------*- C++ -*-===// +// +// Copyright 2024 EVAS Intelligence Co.,Ltd. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +#include "evas/Transform/Linalg/MemoryAlloc.h" + +#define DEBUG_TYPE "memory-alloc" + +namespace mlir::triton::ev { + +//===----------------------------------------------------------------------===// +// BFS Region Visitor +//===----------------------------------------------------------------------===// + +void BFSRegionVisitor::pushSubRegionAndUpdate( + Region *visitedRegion, const std::shared_ptr MA) { + Builder builder(visitedRegion->getParentOp()); + std::set opsToUpdate; + + visitedRegion->walk([&](Operation *op) { + if (op->getParentRegion() != visitedRegion) + return; + if (isa(op)) { + opsToUpdate.insert(op); + } else if (op->getNumRegions() != 0) { + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + Region *subRegion = &op->getRegion(i); + if (!subRegion->empty()) { + opsToUpdate.insert(op); + toVisitRegion.push(subRegion); + } + } + } + }); + + for (auto op : opsToUpdate) { + std::vector liveAttrs; + for (const auto &phyBuffer : MA->getAllocResult()) { + if (phyBuffer->isOverflow() || phyBuffer->isSubRegionBuf()) + continue; + if (phyBuffer->isExternal() || phyBuffer->isLiveAt(op)) { + auto liveInfo = + builder.getI64ArrayAttr({phyBuffer->addr(), phyBuffer->size()}); + liveAttrs.push_back(liveInfo); + } + } + if (liveAttrs.empty()) + continue; + Attribute liveAttrArrayAttr = builder.getArrayAttr(liveAttrs); + op->setAttr(LiveBufString(MA->getScope()), liveAttrArrayAttr); + // Update living information at 'func::FuncOp' + if (update && isa(op)) { + func::FuncOp func = getCalledFunction(cast(op)); + if (func && !func.isDeclaration()) + func->setAttr(LiveBufString(MA->getScope()), liveAttrArrayAttr); + } + } +} + +void BFSRegionVisitor::visit(Region *rootRegion, const Liveness &LN, + const std::shared_ptr MA) { + assert(toVisitRegion.empty() && "Unexpected queue state"); + toVisitRegion.push(rootRegion); + while (!toVisitRegion.empty()) { + Region *visitedRegion = toVisitRegion.front(); + toVisitRegion.pop(); + MA->reset(); + // 1. Init memory allocator + MA->init(visitedRegion, LN); + // 2. Assign physical address for all buffers + MA->allocate(); + // 3. Rewrite physical address to 'isMemoryAllocOp' operators + MA->rewrite(); + // 4. Update living information at 'func::CallOp' and 'opHasSubRegion' + pushSubRegionAndUpdate(visitedRegion, MA); + } +} + +//===----------------------------------------------------------------------===// +// Memory Allocation Implement +//===----------------------------------------------------------------------===// + +void MemoryAllocImpl::runOnFuncAtScope(MemScope memScope, + FunctionOpInterface func, + const Liveness &LN) { + Builder builder(func.getOperation()); + // Prologue: + // 1) remove 'phyAddr' attribute if 'preview = true' + // 2) remove 'overflow' and 'preview' attribute + std::vector fixedBufferAttrs; + func.getFunctionBody().walk([&](Operation *op) { + if (!isMemoryAllocOp(op) || getMemScope(op) != memScope) + return; + if (op->hasAttr(previewName) && + op->getAttrOfType(previewName).getValue()) + op->removeAttr(phyAddrName); + op->removeAttr(previewName); + op->removeAttr(overflowName); + }); + + // Main : Alloc memory for current function + auto DMA = std::make_shared(ToAssignOrder, memScope, + alignment, preview); + BfsRV.visit(&func.getFunctionBody(), LN, DMA); + + // BankOpt: Realloc buffers with bank-alone preference + if (bankopt && memHasBank(memScope)) { + auto MRA = std::make_shared( + ToAssignOrder, memScope, memBankAlignment(memScope), preview); + BfsRV.visit(&func.getFunctionBody(), LN, MRA); + } + + // Epilogue : remove living information in sub regions at release version + LLVM_DEBUG(return;); + func.getFunctionBody().walk([&](Operation *op) { + if (op->getNumRegions() != 0) + op->removeAttr(LiveBufString(memScope)); + }); +} + +void MemoryAllocImpl::runOnFunction(FunctionOpInterface func, + const Liveness &LN) { + // Allocate memory from high-level to low-level + for (size_t id = MemScope::MAX - 1; id > MemScope::UNKNOWN; id--) + runOnFuncAtScope(MemScope(id), func, LN); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/MemoryAllocPass.cpp b/third_party/evas/lib/Transform/Linalg/MemoryAllocPass.cpp new file mode 100644 index 0000000000..8ec13e9385 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/MemoryAllocPass.cpp @@ -0,0 +1,141 @@ +#include "epu/memory.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "evas/Transform/Linalg/MemoryAlloc.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#define GEN_PASS_DECL_MEMORYALLOC +#define GEN_PASS_DEF_MEMORYALLOC +#include "evas/Transform/Linalg/Passes.h.inc" + +#define DEBUG_TYPE "ev-memory-alloc" + +namespace mlir::triton::ev { + +namespace { +using namespace mlir; + +// TODO(wyann): remove this in the future +void setLiveL2ForTransposeCluster(func::CallOp callOp) { + OpBuilder b(callOp); + auto liveL2Str = mlir::ev::LiveBufString(mlir::ev::L2); + auto liveL2Attr = b.getArrayAttr( + {b.getI64ArrayAttr({0, mlir::ev::memCapacity(mlir::ev::L2)})}); + auto liveL1Str = mlir::ev::LiveBufString(mlir::ev::MM); + auto liveL1Attr = b.getArrayAttr( + {b.getI64ArrayAttr({0, mlir::ev::memCapacity(mlir::ev::MM)})}); + callOp->setAttr(liveL2Str, liveL2Attr); + callOp->setAttr(liveL1Str, liveL1Attr); + auto funcOp = mlir::ev::getCalledFunction(callOp); + funcOp->setAttr(liveL2Str, liveL2Attr); + funcOp->setAttr(liveL1Str, liveL1Attr); + return; +} +struct MemoryAllocPass : public ::impl::MemoryAllocBase { + // MemoryAllocPass(size_t memScope, size_t alignment, bool preview, + // bool bankopt) + // : :MemoryAllocBase() { + // this->memScope = memScope; + // this->alignment = alignment; + // this->preview = preview; + // this->bankopt = bankopt; + // } + using MemoryAllocBase::MemoryAllocBase; + + void runOnOperation() override { + const CallGraph &CG = getAnalysis(); + const Liveness &LN = getAnalysis(); + + LLVM_DEBUG(CG.dump()); + + std::set visitedNodes; + std::queue toVisitNodes; + + const CallGraphNode *extCallerNode = CG.getExternalCallerNode(); + for (const CallGraphNode::Edge &edge : *extCallerNode) { + assert(edge.isAbstract() && "Unexpected edge from externel node"); + CallGraphNode *calleeNode = edge.getTarget(); + if (calleeNode->isExternal()) + continue; + Operation *callee = calleeNode->getCallableRegion()->getParentOp(); + assert(isa(callee) && "Unexpected operation"); + // FIXME: Why private function can be called from externel node? + // CallGraph Analysis should be improved. + if (cast(callee).isPublic()) + toVisitNodes.push(calleeNode); + } + + while (!toVisitNodes.empty()) { + CallGraphNode *toVisitNode = toVisitNodes.front(); + toVisitNodes.pop(); + assert(!visitedNodes.count(toVisitNode) && "Node can't be visited twice"); + visitedNodes.insert(toVisitNode); + for (const CallGraphNode::Edge &edge : *toVisitNode) { + assert(edge.isCall() && "TODO: Support children node"); + CallGraphNode *calleeNode = edge.getTarget(); + if (calleeNode->isExternal()) + continue; + toVisitNodes.push(calleeNode); + } + + AllocPolicy policy = {alignment.getValue(), SIZE_PRIOR}; + MemoryAllocImpl MAI(policy, preview.getValue(), /* update callee */ true, + bankopt.getValue()); + Operation *op = toVisitNode->getCallableRegion()->getParentOp(); + assert(isa(op) && "Unexpected operation"); + auto func = cast(op); + if (func.isPublic()) { + if (memScope == 0) { + // Walk over all functions and set memory allocation result. + MAI.runOnFunction(func, LN); + } else { + MAI.runOnFuncAtScope(MemScope(memScope.getValue()), func, LN); + } + } + } + + // TODO(wyann): very tricky code here, delete in the future + // Iterate through all call ops in the module to find transpose subkernel + getOperation()->walk([&](func::CallOp callOp) { + auto calledFunc = mlir::ev::getCalledFunction(callOp); + if (!calledFunc) + return; + bool hasTranspose = false; + calledFunc->walk( + [&](linalg::TransposeOp transposeOp) { hasTranspose = true; }); + if (!hasTranspose) + return; + // Check if the function has only MM scope attribute + if (auto memScopeAttr = + callOp->getAttrOfType(mlir::ev::MEMSCOPE)) { + for (auto memScope : memScopeAttr) { + if (cast(memScope).getInt() != mlir::ev::MM) + return; + } + } else { + return; + } + setLiveL2ForTransposeCluster(callOp); + }); + } +}; +} // namespace + +std::unique_ptr createMemoryAllocPass() { + return std::make_unique(); +} + +std::unique_ptr createMemoryAllocPass(size_t memScope, size_t alignment, + bool preview, bool bankopt) { + MemoryAllocOptions opts{memScope, alignment, preview, bankopt}; + return std::make_unique(opts); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/MemoryPromotionPass.cpp b/third_party/evas/lib/Transform/Linalg/MemoryPromotionPass.cpp new file mode 100644 index 0000000000..4af7db354a --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/MemoryPromotionPass.cpp @@ -0,0 +1,161 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define GEN_PASS_DEF_MEMORYPROMOTION +#include "evas/Transform/Linalg/Passes.h.inc" +namespace mlir::triton::ev { + +namespace { + +LogicalResult promoteOutput(Operation *op, unsigned outputIndex, + mlir::ev::MemScope srcScope, + mlir::ev::MemScope dstScope, + PatternRewriter &rewriter) { + if (op->getNumOperands() <= outputIndex) { + return failure(); + } + + Value outputOperand = op->getOperand(outputIndex); + MemRefType memRefType = dyn_cast(outputOperand.getType()); + if (!memRefType) { + return failure(); + } + + mlir::ev::MemScope currentScope = mlir::ev::getMemScope(memRefType); + if (currentScope != srcScope) { + return failure(); + } + + MemRefType promotedType = MemRefType::get( + memRefType.getShape(), memRefType.getElementType(), + memRefType.getLayout(), + rewriter.getI64IntegerAttr(static_cast(dstScope))); + + auto promotedAlloc = + rewriter.create(op->getLoc(), promotedType); + Value promotedValue = Value(promotedAlloc.getResult()); + + op->setOperand(outputIndex, promotedValue); + rewriter.setInsertionPointAfter(op); + rewriter.create(op->getLoc(), promotedValue, outputOperand); + + return success(); +} + +LogicalResult promoteInput(Operation *op, unsigned inputIndex, + mlir::ev::MemScope srcScope, + mlir::ev::MemScope dstScope, + PatternRewriter &rewriter) { + if (op->getNumOperands() <= inputIndex) { + return failure(); + } + + Value inputOperand = op->getOperand(inputIndex); + MemRefType memRefType = dyn_cast(inputOperand.getType()); + if (!memRefType) { + return failure(); + } + + mlir::ev::MemScope currentScope = mlir::ev::getMemScope(memRefType); + if (currentScope != srcScope) { + return failure(); + } + + MemRefType promotedType = MemRefType::get( + memRefType.getShape(), memRefType.getElementType(), + memRefType.getLayout(), + rewriter.getI64IntegerAttr(static_cast(dstScope))); + + auto promotedAlloc = + rewriter.create(op->getLoc(), promotedType); + rewriter.create(op->getLoc(), inputOperand, promotedAlloc); + + op->setOperand(inputIndex, promotedAlloc); + + return success(); +} + +struct MatmulMemoryPromotionPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + return promoteOutput(matmulOp.getOperation(), 2, mlir::ev::MemScope::MM, + mlir::ev::MemScope::FAM, rewriter); + } +}; + +struct PromoteLinalgOperandsToMMPattern : public RewritePattern { + PromoteLinalgOperandsToMMPattern(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto *dialect = op->getDialect(); + if (!dialect || + dialect->getNamespace() != linalg::LinalgDialect::getDialectNamespace()) + return failure(); + // (TODO) Tricky code here, since takeop and scatterop's first input always from ddr + if (isa(op)) return failure(); + bool promotedAny = false; + auto linalgOp = dyn_cast(op); + unsigned numInputs = + linalgOp ? linalgOp.getNumDpsInputs() : op->getNumOperands(); + + for (auto it : llvm::enumerate(op->getOperands())) { + Value operand = it.value(); + auto memRefType = dyn_cast(operand.getType()); + if (!memRefType) + continue; + + mlir::ev::MemScope currentScope = mlir::ev::getMemScope(memRefType); + if (currentScope >= mlir::ev::MemScope::MM) + continue; + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + LogicalResult result = failure(); + if (it.index() < numInputs) { + result = promoteInput(op, it.index(), currentScope, + mlir::ev::MemScope::MM, rewriter); + } else { + result = promoteOutput(op, it.index(), currentScope, + mlir::ev::MemScope::MM, rewriter); + } + if (succeeded(result)) { + promotedAny = true; + } + } + + return promotedAny ? success() : failure(); + } +}; +} // namespace + +struct MemoryPromotionPass + : public ::impl::MemoryPromotionBase { + using MemoryPromotionBase::MemoryPromotionBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +std::unique_ptr createMemoryPromotionPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/RegionMemAllocator.cpp b/third_party/evas/lib/Transform/Linalg/RegionMemAllocator.cpp new file mode 100644 index 0000000000..c114c5a0cc --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RegionMemAllocator.cpp @@ -0,0 +1,401 @@ +//===------------------- RegionMemAllocator.cpp -----------------*- C++ -*-===// +// +// Copyright 2024 EVAS Intelligence Co.,Ltd. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "evas/Transform/Linalg/RegionMemAllocator.h" +#include "epu/memory.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Operation.h" +#include +#include +#include + +namespace mlir::triton::ev { + +//===----------------------------------------------------------------------===// +// Basic Memory Allocator +//===----------------------------------------------------------------------===// + +void MemAllocator::assignAddrOrder( + const std::shared_ptr visitedBuffer, int64_t align) { + int64_t nextToAssignAddr = 0; + for (const auto &phyBuffer : phyBufsOrder) { + if (phyBuffer->isOverflow()) + continue; + if (phyBuffer->isExternal() || + phyBuffer->isConflictWith(visitedBuffer->getLiveInterval())) { + if (nextToAssignAddr + visitedBuffer->size() <= phyBuffer->addr()) + break; + int64_t upBound = llvm::alignTo(phyBuffer->upperBound(), align); + nextToAssignAddr = std::max(upBound, nextToAssignAddr); + } + } + visitedBuffer->setPhyAddr(nextToAssignAddr); + insertPhyBuf(visitedBuffer); +} + +void MemAllocator::assignAddrReverse( + const std::shared_ptr visitedBuffer, int64_t align) { + // FIXME: Improve next addresee computation with real size. + int64_t alignedBufSize = llvm::alignTo(visitedBuffer->size(), alignment); + int64_t nextToAssignAddr = + memCapacity(visitedBuffer->scope()) - alignedBufSize; + for (const auto &phyBuffer : phyBufsReverse) { + if (phyBuffer->isOverflow()) + continue; + if (phyBuffer->isExternal() || + phyBuffer->isConflictWith(visitedBuffer->getLiveInterval())) { + if (nextToAssignAddr >= phyBuffer->upperBound()) + break; + nextToAssignAddr = std::min(nextToAssignAddr, + phyBuffer->addr() - alignedBufSize); + } + } + visitedBuffer->setPhyAddr(nextToAssignAddr); + insertPhyBuf(visitedBuffer); +} + +void MemAllocator::insertPhyBuf(const std::shared_ptr buffer) { + phyBufsOrder.insert(buffer); + phyBufsReverse.insert(buffer); +} + +void MemAllocator::erasePhyBuf(const std::shared_ptr buffer) { + phyBufsOrder.erase(buffer); + phyBufsReverse.erase(buffer); +} + +void MemAllocator::revertPhyBuf(const std::shared_ptr buffer) { + erasePhyBuf(buffer); + buffer->setPhyAddr(0); + virtualBufs.insert(buffer); +} + +std::shared_ptr MemAllocator::pickNextBuffer() { + assert(!virtualBufs.empty() && "No buffer to pick"); + auto visitedBuffer = *virtualBufs.begin(); + virtualBufs.erase(visitedBuffer); + return visitedBuffer; +} + +void MemAllocator::initExternelBufs(const ArrayAttr &externalLives, + uint64_t &nextSlotIndex) { + // Init externel buffers using living information at attribute. + for (auto liveBuffer : externalLives) { + ArrayAttr liveInfo = cast(liveBuffer); + int64_t phyAddr = cast(liveInfo[0]).getInt(); + int64_t size = cast(liveInfo[1]).getInt(); + auto phyBuf = std::make_shared(memScope, phyAddr, size); + phyBuf->setSlotIndex(nextSlotIndex++); + insertPhyBuf(phyBuf); + } + // Checking there is no externel buffer overlap with another. + auto isVaildInitBuffer = [](const LiveBufferSet &phyBufsOrder) { + int64_t slidePtr = 0; + for (const auto &phyBuffer : phyBufsOrder) { + if (!phyBuffer->isExternal()) + continue; + if (phyBuffer->isOverflow() || phyBuffer->addr() < slidePtr) + return false; + slidePtr = phyBuffer->upperBound(); + } + return true; + }; + assert(isVaildInitBuffer(phyBufsOrder) && "Unexpected initial state"); +} + +// Warning: There is no guarantee for the validity of fixed physical buffers. +void MemAllocator::initCurrRegionPhyBufs( + const std::vector &opsWithFixedAddr, const Liveness &LN, + uint64_t &nextSlotIndex) { + for (const auto &op : opsWithFixedAddr) { + auto fixedBuf = std::make_shared(op); + // set phyAddr, slotIndex, priority and liveness + int64_t addr = op->getAttrOfType(phyAddrName).getInt(); + fixedBuf->setPhyAddr(addr); + fixedBuf->setSlotIndex(nextSlotIndex++); + fixedBuf->setPriority(getMemoryPrior(op)); + // TODO: Improve Live-interval to speed up analysis + auto liveInterval = LN.resolveLiveness(fixedBuf->value()); + fixedBuf->setLiveInterval(liveInterval); + insertPhyBuf(fixedBuf); + } +} + +void MemAllocator::initSubRegionPhyBufs( + const std::vector &opsHasSubRegion, uint64_t &nextSlotIndex) { + // Create a merged buffer and insert this to physical buffer set + auto initMergedBuffer = [&](Operation *op, int64_t addr, int64_t upper, + uint64_t &slotIndex) { + int64_t size = upper - addr; + auto mergedBuffer = std::make_shared(op, memScope, addr, size); + mergedBuffer->setLiveInterval(op); + mergedBuffer->setSlotIndex(slotIndex++); + insertPhyBuf(mergedBuffer); + }; + for (const auto &opHasSubRegion : opsHasSubRegion) { + // 1. Collect all of allocation operations nested under the given sub region + LiveBufferSet toMergeBuffers(LayoutOrder); + opHasSubRegion->walk([&](Operation *op) { + if (!isMemoryAllocOp(op) || getMemScope(op) != memScope) + return; + if (op->hasAttr(overflowName) && + op->getAttrOfType(previewName).getValue()) + return; + if (!op->hasAttr(phyAddrName)) + return; + auto fixedBuf = std::make_shared(op); + int64_t addr = op->getAttrOfType(phyAddrName).getInt(); + fixedBuf->setPhyAddr(addr); + toMergeBuffers.insert(fixedBuf); + }); + if (toMergeBuffers.empty()) + continue; + // 2. Try to merge all buffers and create fake physical buffer with liveness + int64_t mergedBufAddr = -1; + int64_t mergedBufUpper = -1; + for (const auto &toMergeBuffer : toMergeBuffers) { + if (mergedBufAddr == -1 || mergedBufUpper == -1) { + mergedBufAddr = toMergeBuffer->addr(); + mergedBufUpper = toMergeBuffer->upperBound(); + } + if (toMergeBuffer->addr() > mergedBufUpper && + mergedBufUpper > mergedBufAddr) { + initMergedBuffer(opHasSubRegion, mergedBufAddr, mergedBufUpper, + nextSlotIndex); + mergedBufAddr = toMergeBuffer->addr(); + mergedBufUpper = toMergeBuffer->upperBound(); + continue; + } + mergedBufUpper = + std::max(toMergeBuffer->upperBound(), mergedBufUpper); + } + if (mergedBufAddr != -1 && mergedBufUpper != -1) { + assert(mergedBufUpper > mergedBufAddr && "Unexpected buffer"); + initMergedBuffer(opHasSubRegion, mergedBufAddr, mergedBufUpper, + nextSlotIndex); + } + } +} + +void MemAllocator::initRegionBuffers(Region *region, const Liveness &LN, + uint64_t &nextSlotIndex) { + // Walk through operations to identify and map buffers + std::shared_ptr buffer; + for (auto &op : region->getOps()) { + if (isSubkernelBufferOp(&op)) { + for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { + auto result = op.getResult(i); + auto scope = MemScope( + cast(op.getAttrOfType( + ev::MEMSCOPE)[op.getNumOperands() + i]) + .getInt()); + if (scope != memScope) + continue; + int64_t address = -1; + if (op.hasAttr(ev::addrName)) { + address = + cast(op.getAttrOfType(ev::addrName)[i]) + .getInt(); + } + buffer = std::make_shared( + result, scope, address < 0 ? 0 : address, nextSlotIndex++, + LN.resolveLiveness(result)); + if (address < 0) { + virtualBufs.insert(std::move(buffer)); + } else { + insertPhyBuf(buffer); + } + } + } else if (isMemoryAllocOp(&op) && getMemScope(&op) == memScope) { + int64_t address = -1; + if (op.hasAttr(ev::phyAddrName)) { + address = op.getAttrOfType(ev::phyAddrName).getInt(); + } + buffer = std::make_shared( + op.getResult(0), memScope, address < 0 ? 0 : address, nextSlotIndex++, + LN.resolveLiveness(op.getResult(0))); + if (address < 0) { + virtualBufs.insert(std::move(buffer)); + } else { + insertPhyBuf(buffer); + } + } else { + continue; + } + } +} + +void MemAllocator::init(Region *region, const Liveness &LN) { + // Collect all wanted operators in mentioned region: + // 1) 'isMemoryAllocOp' with memory space equal to 'memScope', + // but no attribute 'phyAddr' attached. + // 2) 'isMemoryAllocOp' with memory space equal to 'memScope', + // and has attribute 'phyAddr'. + // 3) operation has sub regions, which is necessary for liveness analysis + uint64_t nextSlotIndex = 0; + initRegionBuffers(region, LN, nextSlotIndex); + + // Init virtual buffers with liveness + // initVirtualBufs(memBuffers, LN, nextSlotIndex); + // Init external buffers. + if (auto externalLives = region->getParentOp()->getAttrOfType( + LiveBufString(memScope))) + initExternelBufs(externalLives, nextSlotIndex); + // Init fixed physical buffers. + // initCurrRegionPhyBufs(opsWithFixedAddr, LN, nextSlotIndex); + // initSubRegionPhyBufs(opsHasSubRegion, nextSlotIndex); + assert(virtualBufs.size() + phyBufsOrder.size() == nextSlotIndex && + "Error buffer size after init"); +} + +void MemAllocator::rewrite() { + + for (const auto &phyBuffer : getAllocResult()) { + if (phyBuffer->isExternal() || phyBuffer->isSubRegionBuf()) + continue; + Operation *alloc = phyBuffer->getOperation(); + Builder builder(alloc); + auto setAddr = [&builder](Operation *op, int64_t addr) { + auto phyAddrAttr = builder.getI64IntegerAttr(addr); + op->setAttr(phyAddrName, phyAddrAttr); + }; + if (isPreview() && !alloc->hasAttr(phyAddrName)) + alloc->setAttr(previewName, builder.getBoolAttr(true)); + if (isMemoryAllocOp(alloc)) { + if (phyBuffer->isOverflow()) { + auto sizeAttr = builder.getI64IntegerAttr(phyBuffer->size()); + alloc->setAttr(overflowName, sizeAttr); + setAddr(alloc, -1); + } else { + setAddr(alloc, phyBuffer->fixedAddr()); + } + } else { + assert(isSubkernelBufferOp(alloc) && + "Only support subkernel buffer or memory allocation"); + auto bufValue = phyBuffer->value(); + auto outIdx = cast(bufValue).getResultNumber(); + // Update the address at the specific index + if (phyBuffer->isOverflow()) { + alloc->setAttr(overflowName, + builder.getI64IntegerAttr(phyBuffer->size())); + setAddrAtIndex(alloc, outIdx, -1); + } else { + setAddrAtIndex(alloc, outIdx, phyBuffer->fixedAddr()); + } + } + } +} + +void MemAllocator::reset() { + virtualBufs.clear(); + phyBufsOrder.clear(); + phyBufsReverse.clear(); +} + +//===----------------------------------------------------------------------===// +// Order Memory Allocator +//===----------------------------------------------------------------------===// + +void OrderMemAllocator::allocate() { + while (!isVirtBufEmpty()) { + auto visitedBuffer = pickNextBuffer(); + assignAddrOrder(visitedBuffer, getAlign()); + if (isPreview() || !visitedBuffer->isOverflow()) + continue; + assert(false && "TODO: support buffer spill"); + } +} + +//===----------------------------------------------------------------------===// +// Dual-Directional Memory Allocator +//===----------------------------------------------------------------------===// + +void DualMemAllocator::allocate() { + while (!isVirtBufEmpty()) { + auto visitedBuffer = pickNextBuffer(); + // allocate ddr buffer in reverse order for now to avoid collision with input and output ddr buffers + if (visitedBuffer->isPreferBankAlone() || visitedBuffer->scope() == MemScope::DDR) { + assignAddrReverse(visitedBuffer, getAlign()); + } else { + assignAddrOrder(visitedBuffer, getAlign()); + } + if (isPreview() || !visitedBuffer->isOverflow()) + continue; + assert(false && "TODO: support buffer spill"); + } +} + +//===----------------------------------------------------------------------===// +// Bank Optimization Memory Allocator +//===----------------------------------------------------------------------===// + +void BankOptAllocator::allocate() { + auto toReallocBufs = LiveBufferSet(LayoutOrder); + + for (const auto &phyBuffer : getAllocResult()) { + assert((isPreview() || !phyBuffer->isOverflow()) && "Unsupported Spill"); + if (phyBuffer->isExternal() || phyBuffer->isSubRegionBuf()) + continue; + // FIXME: Maybe buffer specified by user has priority ? + if (phyBuffer->isPreferBankAlone() && !phyBuffer->isOverflow()) + toReallocBufs.insert(phyBuffer); + } + + assert(isVirtBufEmpty() && "All buffer must have been assigned"); + + while (!toReallocBufs.empty()) { + // 1. Revert a buffer with least physical address + const auto &reallocBuf = *toReallocBufs.begin(); + int64_t oriPhyAddr = reallocBuf->addr(); + toReallocBufs.erase(reallocBuf); + revertPhyBuf(reallocBuf); + // 2. Reassign physical address + auto visitedBuffer = pickNextBuffer(); + assert(visitedBuffer->getOperation() == reallocBuf->getOperation() && + "Next buffer must be reallocBuf"); + assignAddrOrder(visitedBuffer, getAlign()); + if (visitedBuffer->isOverflow()) { + erasePhyBuf(visitedBuffer); + visitedBuffer->setPhyAddr(oriPhyAddr); + insertPhyBuf(visitedBuffer); + } + } +} + +void BankOptAllocator::rewrite() { + for (const auto &phyBuffer : getAllocResult()) { + if (phyBuffer->isExternal() || phyBuffer->isSubRegionBuf()) + continue; + if (!phyBuffer->isPreferBankAlone() || phyBuffer->isOverflow()) + continue; + Operation *alloc = phyBuffer->getOperation(); + if (isMemoryAllocOp(alloc)) { + Builder builder(alloc); + auto phyAddrAttr = builder.getI64IntegerAttr(phyBuffer->fixedAddr()); + alloc->setAttr(phyAddrName, phyAddrAttr); + } else { + assert(isSubkernelBufferOp(alloc) && + "Only support subkernel buffer or memory allocation"); + auto bufValue = phyBuffer->value(); + auto outIdx = cast(bufValue).getResultNumber(); + // Update the address at the specific index + setAddrAtIndex(alloc, outIdx, phyBuffer->fixedAddr()); + } + } +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/RemoveLoopIterArgsWithMemrefType.cpp b/third_party/evas/lib/Transform/Linalg/RemoveLoopIterArgsWithMemrefType.cpp new file mode 100644 index 0000000000..e4e5ec0ead --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RemoveLoopIterArgsWithMemrefType.cpp @@ -0,0 +1,247 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "llvm/ADT/SmallVector.h" +#define GEN_PASS_DEF_REMOVELOOPITERARGSWITHMEMREFTYPE +#include "evas/Transform/Linalg/Passes.h.inc" +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; + +scf::ForOp isInForLoop(Operation *op) { + // Walk up the parent operations to find a scf::ForOp + Operation *currentOp = op; + while (currentOp) { + if (auto forOp = dyn_cast(currentOp)) { + return forOp; + } + currentOp = currentOp->getParentOp(); + } + return nullptr; +} + +Value findInitArgFromIterArg(Value iterArg, scf::ForOp forOp) { + // Get the position of iterArg in the loop's region iter args + int position = -1; + for (auto [idx, arg] : llvm::enumerate(forOp.getRegionIterArgs())) { + if (arg == iterArg) { + position = idx; + break; + } + } + + // If iterArg not found in region iter args, return nullptr + if (position == -1) + return nullptr; + + // Return the corresponding init arg at the same position + return forOp.getInitArgs()[position]; +} + +LogicalResult replaceMemRefIterArgs(scf::ForOp forOp, BufferOriginAnalysis &bufferOriginAnalysis) { + // Track which init args are memref type and have different yield values + llvm::SmallVector shouldReplace(forOp.getNumRegionIterArgs(), false); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + // Check each init arg and corresponding yield value + for (auto [i, initAndYield] : + llvm::enumerate(llvm::zip(forOp.getInitArgs(), yieldOp.getOperands()))) { + auto [initArg, yieldVal] = initAndYield; + + // Check if init arg is memref type + if (isa(initArg.getType())) { + // Check if yield value is different from init arg + shouldReplace[i] = !bufferOriginAnalysis.isSameAllocation(initArg, yieldVal).value(); + } + } + + // If no replacement needed, return failure + if (llvm::none_of(shouldReplace, [](bool replace) { return replace; })) { + return failure(); + } + + // Replace uses of yield values with iter args where needed + for (auto [i, vals] : llvm::enumerate(llvm::zip(yieldOp.getOperands(), + forOp.getRegionIterArgs(), + forOp.getInitArgs()))) { + if (shouldReplace[i]) { + auto [yieldVal, iterArg, initArg] = vals; + // Replace all uses of yield value with iter arg within the loop body + yieldVal.replaceAllUsesWith(initArg); + iterArg.replaceAllUsesWith(initArg); + } + } + return success(); +} + +LogicalResult removeUnusedIterArgs(scf::ForOp forOp) { + // Track which iter args are used + llvm::SmallVector isIterArgUsed(forOp.getNumRegionIterArgs(), false); + + // Check usage of each iter arg in the loop body + forOp.getBody()->walk([&](Operation *op) { + for (Value operand : op->getOperands()) { + for (auto [idx, iterArg] : llvm::enumerate(forOp.getRegionIterArgs())) { + if (operand == iterArg) { + isIterArgUsed[idx] = true; + } + } + } + }); + + // If all iter args are used, nothing to do + if (llvm::all_of(isIterArgUsed, [](bool used) { return used; })) { + return failure(); + } + + // Collect used iter args, init args and yield operands + SmallVector newInitArgs; + SmallVector newIterArgs; + SmallVector newYieldOperands; + SmallVector newOutputs; + auto yieldOp = cast(forOp.getBody()->getTerminator()); + + for (auto i : llvm::seq(0, forOp.getNumRegionIterArgs())) { + if (isIterArgUsed[i]) { + newInitArgs.push_back(forOp.getInitArgs()[i]); + newIterArgs.push_back(forOp.getRegionIterArgs()[i]); + newYieldOperands.push_back(yieldOp.getOperands()[i]); + newOutputs.push_back(forOp->getResult(i)); + } else { + forOp->getResult(i).replaceAllUsesWith(forOp.getInitArgs()[i]); + } + } + + // Create new ForOp with only used iter args + OpBuilder builder(forOp); + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + // Clone the body of old ForOp into new ForOp using IRMapping + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (auto [oldArg, newArg] : + llvm::zip(newIterArgs, newForOp.getRegionIterArgs())) { + mapping.map(oldArg, newArg); + } + + // Remove the automatically created terminator if it exists + Block *newBody = newForOp.getBody(); + if (!newBody->empty() && isa(newBody->getTerminator())) { + newBody->back().erase(); + } + + builder.setInsertionPointToStart(newBody); + + // Clone all operations except the terminator + for (auto &op : forOp.getBody()->without_terminator()) { + builder.clone(op, mapping); + } + + // Update yield operands by mapping them through IRMapping + SmallVector mappedYieldOperands; + for (Value yieldOp : newYieldOperands) { + mappedYieldOperands.push_back(mapping.lookupOrDefault(yieldOp)); + } + builder.create(forOp.getLoc(), mappedYieldOperands); + + // Replace old ForOp with new one + for (auto [oldResult, newOutput] : + llvm::zip(newOutputs, newForOp->getResults())) { + oldResult.replaceAllUsesWith(newOutput); + } + forOp->erase(); + + return success(); +} + +bool CanonicalizeMemRefIterArgs(scf::ForOp forOp, BufferOriginAnalysis &bufferOriginAnalysis) { + return succeeded(replaceMemRefIterArgs(forOp, bufferOriginAnalysis)) && + succeeded(removeUnusedIterArgs(forOp)); +} + +// class LinalgAddInplacePattern : public OpRewritePattern { +// public: +// using OpRewritePattern::OpRewritePattern; + +// LogicalResult matchAndRewrite(linalg::AddOp op, +// PatternRewriter &rewriter) const override { + +// Value lhs = op.getOperand(0); +// Value rhs = op.getOperand(1); +// Value dst_output = op.getOperand(2); +// Value output = op.getResult(0); + +// if (lhs == dst_output || rhs == dst_output) +// return failure(); + +// // Check if value is in the iter_args of a scf.for loop +// auto isInIterArgs = [](Value value, scf::ForOp forOp) { +// for (Value iterArg : forOp.getRegionIterArgs()) { +// if (iterArg == value) +// return true; +// } +// return false; +// }; + +// auto forOp = isInForLoop(op); +// if (!forOp) +// return failure(); + +// // Create a new linalg op with the iter_args input replacing dst_output +// SmallVector outOperands; +// SmallVector inputOperands; +// if (isInIterArgs(rhs, forOp)) { +// outOperands.push_back(findInitArgFromIterArg( +// rhs, forOp)); // Use rhs as both input and output +// inputOperands = {lhs, outOperands[0]}; +// } else if (isInIterArgs(lhs, forOp)) { +// outOperands.push_back(findInitArgFromIterArg( +// lhs, forOp)); // Use lhs as both input and output +// inputOperands = {outOperands[0], rhs}; +// } else { +// return failure(); +// } + +// auto newOp = rewriter.create( +// op.getLoc(), op.getResultTypes(), inputOperands, outOperands, +// linalg::getPrunedAttributeList(op)); + +// rewriter.replaceOp(op, newOp.getResults()); +// return success(); +// } +// }; +struct RemoveLoopIterArgsWithMemrefTypePass + : public ::impl::RemoveLoopIterArgsWithMemrefTypeBase { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder builder(moduleOp.getContext()); + BufferOriginAnalysis bufferOriginAnalysis(moduleOp); + // RewritePatternSet patterns(&getContext()); + // patterns.add(&getContext()); + // (void)mlir::applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); + moduleOp.walk([&](scf::ForOp forOp) { + if (CanonicalizeMemRefIterArgs(forOp, bufferOriginAnalysis)) { + return WalkResult::advance(); + } + return WalkResult::skip(); + }); + } +}; + +} // namespace +std::unique_ptr createRemoveLoopIterArgsWithMemrefTypePass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/RemoveRedundencyCopyPass.cpp b/third_party/evas/lib/Transform/Linalg/RemoveRedundencyCopyPass.cpp new file mode 100644 index 0000000000..bd9cb659b8 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RemoveRedundencyCopyPass.cpp @@ -0,0 +1,215 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) 2024 The EVAS Intelligence Inc. All Rights Reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//============================================================================== + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define GEN_PASS_DEF_REMOVEREDUNDENCYCOPY +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { + +namespace { + +// Pattern 1: Eliminate self-copy: copy(A, A) -> no-op +struct EliminateSelfCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + Value source = copyOp.getSource(); + Value target = copyOp.getTarget(); + + // Check if source and target are the same value + if (source == target) { + rewriter.eraseOp(copyOp); + return success(); + } + + return failure(); + } +}; + +// Pattern 2: Chain copy elimination: copy(A, B) + copy(B, C) -> copy(A, C) +struct ChainCopyEliminationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + Value source = copyOp.getSource(); + Value target = copyOp.getTarget(); + + for (Operation *user : target.getUsers()) { + if (user == copyOp) + continue; + + auto nextCopy = dyn_cast(user); + if (!nextCopy) + continue; + + // Check if nextCopy uses target as source: copy(B, C) + if (nextCopy.getSource() == target) { + // Check dominance: copyOp must dominate nextCopy + DominanceInfo domInfo(copyOp->getParentOfType()); + if (!domInfo.dominates(copyOp.getOperation(), nextCopy.getOperation())) + continue; + + // Check if target is only used by these two copy operations + // (or we need to be more careful about aliasing) + bool hasUserBetween = false; + bool hasUserAfter = false; + for (Operation *targetUser : target.getUsers()) { + if (targetUser != copyOp && targetUser != nextCopy) { + if (copyOp->isBeforeInBlock(targetUser) && + targetUser->isBeforeInBlock(nextCopy)) { + hasUserBetween = true; + break; + } else { + hasUserAfter = true; + } + } + } + if (hasUserBetween) + continue; + // Create new copy: copy(A, C) + Value finalTarget = nextCopy.getTarget(); + rewriter.setInsertionPoint(nextCopy); + rewriter.create(copyOp.getLoc(), source, finalTarget); + // If the target is only used after the copy, and the types are the + // same, we can replace the target with the final target + if (hasUserAfter && target.getType() == finalTarget.getType()) { + rewriter.replaceAllUsesWith(target, finalTarget); + } + rewriter.eraseOp(nextCopy); + rewriter.eraseOp(copyOp); + return success(); + } + } + + return failure(); + } +}; + +// Pattern 3: Eliminate redundant copy: copy(A, B) followed by another copy(A, +// B) +struct EliminateRedundantCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + Value source = copyOp.getSource(); + Value target = copyOp.getTarget(); + + // Look for another copy operation with the same source and target + for (Operation *user : source.getUsers()) { + if (user == copyOp) + continue; + + auto otherCopy = dyn_cast(user); + if (!otherCopy) + continue; + + // Check if it's the same copy: copy(A, B) + if (otherCopy.getSource() == source && otherCopy.getTarget() == target) { + // Check dominance: one must dominate the other + DominanceInfo domInfo(copyOp->getParentOfType()); + bool otherDominates = + domInfo.dominates(otherCopy.getOperation(), copyOp.getOperation()); + bool thisDominates = + domInfo.dominates(copyOp.getOperation(), otherCopy.getOperation()); + + if (!otherDominates && !thisDominates) + continue; + + Operation *firstCopy = + otherDominates ? otherCopy.getOperation() : copyOp.getOperation(); + Operation *secondCopy = + otherDominates ? copyOp.getOperation() : otherCopy.getOperation(); + + // Simple check: if there are any writes to source or target between + // the two copies, we can't eliminate + bool hasWriteBetween = false; + for (Operation &op : + llvm::make_range(std::next(firstCopy->getIterator()), + secondCopy->getIterator())) { + if (auto storeOp = dyn_cast(&op)) { + if (storeOp.getMemRef() == source || + storeOp.getMemRef() == target) { + hasWriteBetween = true; + break; + } + } + if (auto otherCopyOp = dyn_cast(&op)) { + if (otherCopyOp.getTarget() == source || + otherCopyOp.getTarget() == target) { + hasWriteBetween = true; + break; + } + } + } + + if (hasWriteBetween) + continue; + + // Eliminate the second copy + rewriter.eraseOp(secondCopy); + return success(); + } + } + + return failure(); + } +}; + +struct RemoveRedundencyCopyPass + : public ::impl::RemoveRedundencyCopyBase { + using RemoveRedundencyCopyBase< + RemoveRedundencyCopyPass>::RemoveRedundencyCopyBase; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + MLIRContext &context = getContext(); + + RewritePatternSet patterns(&context); + patterns.add(&context); + patterns.add(&context); + patterns.add(&context); + + // Apply patterns with multiple iterations to handle cascading optimizations + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createRemoveRedundencyCopyPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/RemoveScalar.cpp b/third_party/evas/lib/Transform/Linalg/RemoveScalar.cpp new file mode 100644 index 0000000000..fd24507a6a --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RemoveScalar.cpp @@ -0,0 +1,154 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define GEN_PASS_DECL_REMOVESCALAR +#define GEN_PASS_DEF_REMOVESCALAR +#include "evas/Transform/Linalg/Passes.h.inc" + +#define DEBUG_TYPE "convert-to-2d" + +/* + description: Eliminate the scalars generated by the reduce operation in the + scalar-matrix multiplication process + + example : + before : + %extracted_11 = tensor.extract %reduced_9[%c0_10] : tensor<1xf32> + %16 = tensor.empty() : tensor<3x2xf32> + %17 = linalg.fill ins(%extracted_11 : f32) outs(%16 : tensor<3x2xf32>) -> + tensor<3x2xf32> + + after: + %7 = tensor.empty() : tensor<3x2x1xf32> + %broadcasted = linalg.broadcast ins(%reduced : tensor<1xf32>) outs(%7 : + tensor<3x2x1xf32>) dimensions = [0, 1] %collapsed_0 = tensor.collapse_shape + %broadcasted [[0], [1, 2]] : tensor<3x2x1xf32> into tensor<3x2xf32> + + special case : + before : + %extracted = tensor.extract %reduced[%c0] : tensor<1xf32> + %6 = tensor.empty() : tensor<1xf32> + %7 = linalg.fill ins(%extracted : f32) outs(%6 : tensor<1xf32>) -> + tensor<1xf32> + + after: (Completely eliminated) + +*/ + +namespace mlir::triton::ev { + +namespace { +using namespace mlir; + +struct RemoveScalarPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value extractResult = op.getResult(); + if (!extractResult.hasOneUse()) { + return failure(); + } + + Operation *nextOp = *extractResult.getUsers().begin(); + auto fillOp = dyn_cast(nextOp); + if (!fillOp) { + return failure(); + } + + Value sourceTensor = op.getTensor(); + auto destType = cast(fillOp.getOutputs()[0].getType()); + + std::vector emptyShape = destType.getShape().vec(); + emptyShape.push_back(1); + + bool isEmptyShapeAllOnes = true; + for (int64_t dim : emptyShape) { + if (dim != 1) { + isEmptyShapeAllOnes = false; + break; + } + } + + if (isEmptyShapeAllOnes) { + SmallVector reassociation; + ReassociationIndices expandIndices; + for (int i = 0; i < destType.getRank(); i++) { + expandIndices.push_back(i); + } + reassociation.push_back(expandIndices); + + auto expanded = rewriter.create( + loc, destType, sourceTensor, reassociation); + + rewriter.replaceOp(fillOp, expanded.getResult()); + rewriter.eraseOp(op); + } else { + Value empty = rewriter.create(loc, emptyShape, + destType.getElementType()); + + SmallVector broadcastDims; + for (int64_t i = 0; i < destType.getRank(); i++) { + broadcastDims.push_back(i); + } + + auto broadcastOp = rewriter.create( + loc, sourceTensor, empty, broadcastDims); + + SmallVector reassociation; + for (int i = 0; i < destType.getRank(); i++) { + reassociation.push_back({i}); + } + reassociation.back().push_back(destType.getRank()); + + auto collapsedOp = rewriter.create( + loc, fillOp.getOutputs()[0].getType(), broadcastOp.getResult()[0], + reassociation); + + rewriter.replaceOp(fillOp, collapsedOp.getResult()); + rewriter.eraseOp(op); + if (Operation *emptyOp = fillOp.getOutputs()[0].getDefiningOp()) { + rewriter.eraseOp(emptyOp); + } + } + + return success(); + } + +private: +}; + +struct RemoveScalarPass : public ::impl::RemoveScalarBase { + using RemoveScalarBase::RemoveScalarBase; + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + mlir::ModuleOp moduleOp = getOperation(); + MLIRContext &context = getContext(); + + RewritePatternSet patterns(&context); + patterns.add(&context); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace +std::unique_ptr createRemoveScalarPass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev \ No newline at end of file diff --git a/third_party/evas/lib/Transform/Linalg/RewriteDataType.cpp b/third_party/evas/lib/Transform/Linalg/RewriteDataType.cpp new file mode 100644 index 0000000000..dc0bb41614 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RewriteDataType.cpp @@ -0,0 +1,329 @@ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "evas/Transform/Linalg/Passes.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/ADT/STLExtras.h" +#include +#define GEN_PASS_DEF_REWRITEDATATYPE +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { +using namespace mlir; + +#define DEBUG_TYPE "ev-rewrite-data-type" + +namespace { + +bool isScalar(Value val) { + return !isa(val.getType()); +} + +bool isScalarOp(Operation *op) { + auto isScalarValue = [](Value val) { return isScalar(val); }; + return llvm::all_of(op->getOperands(), isScalarValue) && + llvm::all_of(op->getResults(), isScalarValue); +} + +bool isInsideLinalgRegion(Operation *op) { + for (Operation *parent = op->getParentOp(); parent; + parent = parent->getParentOp()) { + if (auto *dialect = parent->getDialect(); + dialect && dialect->getNamespace() == "linalg") + return true; + } + return false; +} + +/// Check if all i64 values in dense attr are within i32 range +bool allValuesInI32Range(DenseIntElementsAttr attr) { + if (!attr.getElementType().isInteger(64)) + return false; + return llvm::all_of(attr.getValues(), [](const APInt &val) { + int64_t v = val.getSExtValue(); + return v >= INT32_MIN && v <= INT32_MAX; + }); +} + +bool isI64ConstantInI32Range(arith::ConstantOp op) { + auto value = op.getValue(); + if (auto intAttr = dyn_cast(value)) { + if (intAttr.getType().isInteger(64)) { + int64_t val = intAttr.getValue().getSExtValue(); + return val >= INT32_MIN && val <= INT32_MAX; + } + } + if (auto denseAttr = dyn_cast(value)) + return allValuesInI32Range(denseAttr); + return false; +} + +/// Create integer cast operation based on width comparison +Value createIntegerCast(OpBuilder &builder, Location loc, Type resultType, + Value input, bool useSignedExt) { + auto srcIntType = dyn_cast(input.getType()); + auto dstIntType = dyn_cast(resultType); + if (!srcIntType || !dstIntType) + return nullptr; + + unsigned srcWidth = srcIntType.getWidth(); + unsigned dstWidth = dstIntType.getWidth(); + if (srcWidth > dstWidth) + return builder.create(loc, resultType, input).getResult(); + if (srcWidth < dstWidth) { + return useSignedExt ? builder.create(loc, resultType, input) + .getResult() + : builder.create(loc, resultType, input) + .getResult(); + } + return nullptr; +} + +/// Extract integer values from dense attr and convert to target type +template SmallVector extractValues(DenseIntElementsAttr attr) { + SmallVector result; + for (const APInt &val : attr.getValues()) + result.push_back(static_cast(val.getSExtValue())); + return result; +} + +//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + +/// Converts i1 -> i8 and i64 -> i32, including nested types in tensors/memrefs +class DataTypeConverter : public TypeConverter { +public: + DataTypeConverter() { + addConversion([](Type type) { return type; }); + + addConversion([](IntegerType intType) -> Type { + unsigned width = intType.getWidth(); + if (width == 1) + return IntegerType::get(intType.getContext(), 8); + if (width == 64) + return IntegerType::get(intType.getContext(), 32); + return intType; + }); + + addConversion([this](TensorType type) -> Type { + return convertElementType(type, + [](auto t, auto e) { return t.clone(e); }); + }); + + addConversion([this](BaseMemRefType type) -> Type { + return convertElementType( + type, [](auto t, auto e) { return t.cloneWith(std::nullopt, e); }); + }); + + addConversion([this](MemRefType type) -> Type { + return convertElementType( + type, [](auto t, auto e) { return t.cloneWith(t.getShape(), e); }); + }); + + addTargetMaterialization([](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return materialize(builder, resultType, inputs, loc, true); + }); + + addSourceMaterialization([](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return materialize(builder, resultType, inputs, loc, true); + }); + } + +private: + template + Type convertElementType(T type, CloneFn cloneFn) { + Type elemType = type.getElementType(); + Type converted = convertType(elemType); + return (converted != elemType) ? cloneFn(type, converted) : type; + } + + static Value materialize(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc, bool useSignedExt) { + if (inputs.size() != 1) + return nullptr; + if (Value cast = createIntegerCast(builder, loc, resultType, inputs[0], + useSignedExt)) { + return cast; + } + return builder.create(loc, resultType, inputs) + .getResult(0); + } +}; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Converts arith.constant op - handles both value attribute and result type +struct ConstantOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = op.getResult().getType(); + Type convertedType = getTypeConverter()->convertType(resultType); + if (convertedType == resultType) + return failure(); + + Attribute value = op.getValue(); + + // Handle dense tensor constants + if (auto denseAttr = dyn_cast(value)) { + auto newAttr = convertDenseAttr(denseAttr, convertedType); + if (!newAttr) + return failure(); + rewriter.replaceOpWithNewOp(op, convertedType, + newAttr); + return success(); + } + + // Handle scalar integer constants (i64 -> i32) + if (auto intAttr = dyn_cast(value)) { + auto newAttr = + IntegerAttr::get(convertedType, intAttr.getValue().trunc(32)); + rewriter.replaceOpWithNewOp(op, convertedType, + newAttr); + return success(); + } + + return failure(); + } + +private: + /// Convert dense integer attr to new shaped type + static DenseIntElementsAttr convertDenseAttr(DenseIntElementsAttr attr, + Type convertedType) { + auto shapedType = cast(convertedType); + Type elemType = attr.getElementType(); + + // i64 -> i32 conversion + if (elemType.isInteger(64)) { + auto values = extractValues(attr); + return DenseIntElementsAttr::get(shapedType, values); + } + // i1 -> i8 conversion + if (elemType.isInteger(1)) { + auto values = extractValues(attr); + return DenseIntElementsAttr::get(shapedType, values); + } + return nullptr; + } +}; + +/// Generic converter for all ops - clones op with converted types +struct GenericOpConverter : public ConversionPattern { + GenericOpConverter(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector convertedTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + convertedTypes))) + return failure(); + + OperationState state(op->getLoc(), op->getName()); + state.addOperands(operands); + state.addTypes(convertedTypes); + state.addAttributes(op->getAttrs()); + state.addSuccessors(op->getSuccessors()); + for (size_t i = 0; i < op->getNumRegions(); ++i) + state.addRegion(); + + Operation *newOp = rewriter.create(state); + inlineRegionsAndConvertBlockArgs(op, newOp, rewriter); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + +private: + void + inlineRegionsAndConvertBlockArgs(Operation *srcOp, Operation *dstOp, + ConversionPatternRewriter &rewriter) const { + for (auto [src, dst] : + llvm::zip(srcOp->getRegions(), dstOp->getRegions())) { + rewriter.inlineRegionBefore(src, dst, dst.end()); + for (Block &block : dst) { + for (BlockArgument arg : block.getArguments()) { + Type converted = getTypeConverter()->convertType(arg.getType()); + if (converted != arg.getType()) + arg.setType(converted); + } + } + } + } +}; + +class RewriteDataTypePass + : public ::impl::RewriteDataTypeBase { + using RewriteDataTypeBase::RewriteDataTypeBase; + +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + DataTypeConverter typeConverter; + + ConversionTarget target(*context); + + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + // arith.constant with i64 values in i32 range should be converted + if (auto constOp = dyn_cast(op)) { + if (isI64ConstantInI32Range(constOp)) { + return typeConverter.isLegal(op->getResultTypes()); + } + } + + // Scalar ops outside linalg regions are legal (not converted) + if (isScalarOp(op) && !isInsideLinalgRegion(op)) { + return true; + } + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + + target.addLegalOp(); + + RewritePatternSet patterns(context); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(m, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Run canonicalizer to clean up unrealized conversion casts + mlir::PassManager pm(m.getContext()); + pm.addPass(mlir::createCanonicalizerPass()); + (void)pm.run(m); + } +}; + +} // namespace + +std::unique_ptr createRewriteDataTypePass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev + diff --git a/third_party/evas/lib/Transform/Linalg/RewriteFuncOpArgsType.cpp b/third_party/evas/lib/Transform/Linalg/RewriteFuncOpArgsType.cpp new file mode 100644 index 0000000000..82d5857838 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/RewriteFuncOpArgsType.cpp @@ -0,0 +1,103 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "triton-shared/Transform/common_utils.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +#define GEN_PASS_DEF_REWRITEFUNCOPARGSTYPE +#include "evas/Transform/Linalg/Passes.h.inc" +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; + +struct RewriteFuncOpArgsTypePass + : public ::impl::RewriteFuncOpArgsTypeBase { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder builder(moduleOp.getContext()); + SymbolTable symTable(moduleOp); + + // Walk through all AnnotateOp operations in the module + moduleOp.walk([&](func::CallOp op) { + auto funcOp = utils::getCalledFunction(op); + // Check if all operands come from to_tensor ops + SmallVector newOperands; + SmallVector newTypes; + + for (auto operand : op.getOperands()) { + if (auto defOp = operand.getDefiningOp()) { + if (auto toTensorOp = dyn_cast(defOp)) { + newOperands.push_back(toTensorOp.getOperand()); + newTypes.push_back(toTensorOp.getOperand().getType()); + } else { + op->emitError("operands must all come from to_tensor ops"); + return; + } + } + } + + // Create new function type with memref types and no results + auto newFuncType = builder.getFunctionType(newTypes, funcOp.getResultTypes()); + builder.setInsertionPoint(op->getParentOfType()); + // Construct new function name by appending "_memref" to original name + auto newFuncName = funcOp.getName().str() + "_memref"; + // Create new function with updated type + auto newFuncOp = builder.create( + funcOp.getLoc(), newFuncName, newFuncType, + funcOp.getSymVisibilityAttr(), funcOp.getArgAttrsAttr(), + funcOp.getResAttrsAttr()); + newFuncOp->setAttrs(funcOp->getDiscardableAttrDictionary()); + symTable.insert(newFuncOp); + // Copy function body if exists + if (!funcOp.empty()) { + Block &oldEntryBlock = funcOp.getBody().front(); + Block *newEntryBlock = newFuncOp.addEntryBlock(); + + // Insert to_tensor ops at the beginning of the new function + builder.setInsertionPointToStart(newEntryBlock); + SmallVector tensorArgs; + for (auto arg : newFuncOp.getArguments()) { + auto toTensor = + builder.create(arg.getLoc(), + memref::getTensorTypeFromMemRefType(arg.getType()), + arg)->getResult(0); + tensorArgs.push_back(toTensor); + } + + // Clone the rest of function body + IRMapping mapper; + for (auto [oldArg, newArg] : + llvm::zip(oldEntryBlock.getArguments(), tensorArgs)) { + mapper.map(oldArg, newArg); + } + for (auto &op : oldEntryBlock) { + builder.clone(op, mapper); + } + } + // Replace old function with new one + funcOp.erase(); + // Create new call op with memref operands and no results + builder.setInsertionPoint(op); + auto newCall = + builder.create(op.getLoc(), newFuncOp, newOperands); + newCall->setAttrs(op->getDiscardableAttrDictionary()); + // Remove old call op since it's no longer needed + op.replaceAllUsesWith(newCall); + op.erase(); + }); + } +}; + +} // namespace +std::unique_ptr createRewriteFuncOpArgsTypePass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/SetDeviceInfo.cpp b/third_party/evas/lib/Transform/Linalg/SetDeviceInfo.cpp new file mode 100644 index 0000000000..7f359536f2 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/SetDeviceInfo.cpp @@ -0,0 +1,302 @@ +#include "epu/memory.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#define GEN_PASS_DEF_SETDEVICEINFO +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { + +namespace { +/// A pass to insert deallocations for allocated buffers after theirlast use. +using namespace mlir; +using namespace mlir::ev; +static constexpr llvm::StringRef ADDRESS = "address"; +static constexpr llvm::StringRef MEMSCOPE = "mem_scope"; +static constexpr llvm::StringRef COREBIND = "core_bind"; +struct SetDeviceInfoPass : public ::impl::SetDeviceInfoBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + auto funcOps = llvm::to_vector(module.getOps()); + if (!funcOps.empty()) { + auto funcOp = funcOps[0]; + funcOp.setSymName("kernel"); + } + for (auto func : funcOps) { + if (func.isDeclaration()) { + continue; + } + OpBuilder b(func); + // func.walk([&](bufferization::AllocTensorOp op) { SetScopeInfo(op, b); + // }); + func.walk([&](func::CallOp op) { + SetScopeInfo(op, b); + }); + // FixHoistedStoreToSubKernel(b); + } + } + +private: + llvm::DenseMap gScopeMap; + llvm::DenseMap hoistedStoreMap; + + // void SetMemrefScopeInfo(func::FuncOp funcOp) { + // funcOp.walk([&](memref::AllocOp op) { + + // }); + // } + + void FixHoistedStoreToSubKernel(OpBuilder b) { + for (auto &[val, op] : hoistedStoreMap) { + auto callOp = cast(val.getDefiningOp()); + auto storeOp = cast(op); + b.setInsertionPointAfter(storeOp); + auto toTensorOp = b.create( + storeOp->getLoc(), storeOp.getDest().getType(), storeOp.getSource()); + auto dstPtr = storeOp.getDest(); + dstPtr.replaceAllUsesWith(toTensorOp->getResult(0)); + storeOp->erase(); + auto addrAttr = getAddress(dstPtr); + if (addrAttr) { + callOp->setAttr(ADDRESS, addrAttr); + } + } + } + + Attribute getAddress(Value memref) { + auto op = memref.getDefiningOp(); + if (isa(op)) { + if (op->hasAttr(mlir::ev::phyAddrName)) { + return op->getAttr(mlir::ev::phyAddrName); + } + return nullptr; + } + if (isa(op)) + return getAddress(op->getOperand(0)); + llvm_unreachable("unexpected ops when traverse up ptr"); + return nullptr; + } + + MemScope getMemScope(Value val) { + if (isa(val.getType())) { + return mlir::ev::getMemScope(val.getType()); + } + if (gScopeMap.contains(val)) + return gScopeMap[val]; + + auto op = val.getDefiningOp(); + if (!op) + return mlir::ev::MemScope::DDR; + + if (isa(op)) { + MemRefType memType = cast(op).getType(); + return mlir::ev::getMemScope(memType); + } else if (isa(op)) { + auto memSpace = cast(op).getMemorySpace(); + assert(memSpace && isa(*memSpace) && + "Memscope is not allowed to be empty"); + return static_cast(cast(*memSpace).getInt()); + } else if (isa(op)) { + return mlir::ev::getMemScope( + cast(op).getOperand()); + } else if (op->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + for (auto input : op->getOperands()) { + auto memScope = getMemScope(input); + if (memScope != MemScope::UNKNOWN) { + return memScope; + } + } + return MemScope::UNKNOWN; + } else if (isa(op)) { + return MemScope::DDR; + } else { + return MemScope::UNKNOWN; + } + } + + Operation *getOutputStoreCopy(Value val) { + for (auto useOp : val.getUsers()) { + if (isa(useOp)) { + return useOp; + } else if (useOp->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + return getOutputStoreCopy(useOp->getResult(0)); + } + } + return nullptr; + } + bool isFuncArgsPointer(Value ptr) { + auto op = ptr.getDefiningOp(); + if (!op) + return true; + if (isa(op)) { + return isFuncArgsPointer(op->getOperand(0)); + } + return false; + } + + // void SetScopeInfo(bufferization::AllocTensorOp op, OpBuilder b) { + // op.setMemorySpaceAttr(b.getI64IntegerAttr(mlir::ev::MemScope::L2)); + // } + + // void SetScopeInfo(memref::AllocOp op, OpBuilder b) { + // auto memrefType = op.getMemref().getType(); + // if(!memrefType.getMemorySpace()) { + // auto newType = MemRefType::get(memrefType.getShape(), + // memrefType.getElementType(), memrefType.getLayout(), + // b.getI64IntegerAttr(mlir::ev::MemScope::L2)); + // op.getMemref().setType(newType); + // } + // } + + void SetScopeInfo(func::CallOp op, OpBuilder b) { + std::vector memScopeArray; + for (auto input : op.getOperands()) { + if (!gScopeMap.contains(input)) { + auto scope = getMemScope(input); + gScopeMap[input] = scope; + } + memScopeArray.push_back(b.getI64IntegerAttr(gScopeMap[input])); + } + // for (auto output : op->getResults()) { + // if (!gScopeMap.contains(output)) { + // if (auto storeOp = getOutputStoreCopy(output)) { + // auto storeMemref = storeOp->getOperand(1); + // if (!isFuncArgsPointer(storeMemref)) { + // // gScopeMap[output] = mlir::ev::getMemScope(storeMemref); + // hoistedStoreMap[output] = storeOp; + // } else { + // gScopeMap[output] = mlir::ev::MemScope::L2; + // } + // } else { + // gScopeMap[output] = mlir::ev::MemScope::L2; + // } + // } + + // memScopeArray.push_back(b.getI64IntegerAttr(gScopeMap[output])); + // } + op->setAttr(MEMSCOPE, b.getArrayAttr(memScopeArray)); + } + + ArrayAttr getCoreBind(func::CallOp op, OpBuilder b) { + SymbolTable symbolTable(op->getParentOfType()); + func::FuncOp funcOp = symbolTable.lookup(op.getCallee()); + assert(funcOp); + SmallVector outputs; + funcOp.walk([&](func::ReturnOp returnOp) { + for (auto operand : returnOp->getOperands()) { + outputs.push_back(operand); + } + }); + SmallVector coreBindArray; + const size_t core_num = 4; + for (size_t i = 0; i < core_num; ++i) { + coreBindArray.push_back(b.getI64IntegerAttr(i)); + } + auto core_bind = [&](int idx) { + ArrayAttr ret = b.getArrayAttr(coreBindArray); + for (int i = 0; i < idx; i++) + ret = b.getArrayAttr(SmallVector(1, ret)); + return ret; + }; + if (outputs.size() == 1) { + auto output = outputs[0]; + if (isa_and_nonnull( + output.getDefiningOp())) { + return b.getArrayAttr( + SmallVector(1, b.getI64IntegerAttr(0))); + } + if (isa_and_nonnull( + output.getDefiningOp())) { + auto op = output.getDefiningOp(); + if (auto parallelizableAttr = + op->getAttrOfType("isParallelizable")) { + if (!parallelizableAttr.getValue()) { + return b.getArrayAttr( + SmallVector(1, b.getI64IntegerAttr(0))); + } + } + } + auto type = cast(output.getType()); + assert(type && type.hasRank()); + for (auto [idx, dim] : llvm::enumerate(type.getShape())) { + if (dim < core_num || dim % core_num != 0) { + if (idx == type.getRank() - 1) { + return b.getArrayAttr( + SmallVector(1, b.getI64IntegerAttr(0))); + } + } else if (auto reduceOp = dyn_cast_or_null( + output.getDefiningOp())) { + mlir::detail::DenseArrayAttrImpl dimension = + reduceOp.getDimensionsAttr(); + // dimension is not equal to the current idx + auto size = dimension.getSize(); + bool not_equal = true; + for (int i = 0; i < size; i++) { + if (dimension[i] == idx) { + not_equal = false; + break; + } + } + if (not_equal) + return core_bind(idx); + } else if (auto broadcastOp = dyn_cast_or_null( + output.getDefiningOp())) { + mlir::detail::DenseArrayAttrImpl dimension = + broadcastOp.getDimensionsAttr(); + // dimension is the last dimension + if (dimension[0] + 1 == type.getRank()) { + // Only the dimensions before dimension can be bound + if (idx < dimension[0]) + return core_bind(idx); + } else { + // Dimensions no larger than dimension can be bound + if (idx <= dimension[0]) { + return core_bind(idx); + } else { + return b.getArrayAttr( + SmallVector(1, b.getI64IntegerAttr(0))); + } + } + } else if (auto layernormOp = dyn_cast_or_null( + output.getDefiningOp())) { + auto dimension = layernormOp.getDimensionAttr(); + // dimension is not equal to the work idx + if (idx != dimension.getInt()) + return core_bind(idx); + } else { + // The current dimension is divisible by 4, and we don't have to + // determine dimension + return core_bind(idx); + } + } + } else { + return b.getArrayAttr(SmallVector(1, b.getI64IntegerAttr(0))); + } + return b.getArrayAttr(SmallVector(1, b.getI64IntegerAttr(0))); + } + +}; +} // namespace +std::unique_ptr createSetDeviceInfoPass() { + return std::make_unique(); +} +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/SetMemRefScopePass.cpp b/third_party/evas/lib/Transform/Linalg/SetMemRefScopePass.cpp new file mode 100644 index 0000000000..e5d9f1b12f --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/SetMemRefScopePass.cpp @@ -0,0 +1,294 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) 2024 The EVAS Intelligence Inc. All Rights Reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//============================================================================== + +#include "epu/memory.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" + +#define GEN_PASS_DEF_SETMEMREFSCOPE +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { + +namespace { + +IntegerAttr createMemScopeAttr(MLIRContext *context, + ::mlir::ev::MemScope scope) { + return IntegerAttr::get(IntegerType::get(context, 64), + static_cast(scope)); +} + +IntegerAttr createMemScopeAttr(MLIRContext *context, int64_t scope) { + return IntegerAttr::get(IntegerType::get(context, 64), scope); +} + +MemRefType createMemRefTypeWithScope(MemRefType originalType, + ::mlir::ev::MemScope scope) { + return MemRefType::get(originalType.getShape(), originalType.getElementType(), + originalType.getLayout(), + createMemScopeAttr(originalType.getContext(), scope)); +} + +MemRefType createMemRefTypeWithScope(MemRefType originalType, int64_t scope) { + return MemRefType::get(originalType.getShape(), originalType.getElementType(), + originalType.getLayout(), + createMemScopeAttr(originalType.getContext(), scope)); +} + +class MemRefScopeTypeConverter : public TypeConverter { +public: + MemRefScopeTypeConverter() { + addConversion([](Type type) { return type; }); + + addConversion([](BaseMemRefType memrefType) -> Type { + return cast( + memrefType + .clonePtrWith(createMemScopeAttr(memrefType.getContext(), + ::mlir::ev::MemScope::DDR), + std::nullopt) + .value()); + }); + + addConversion([this](MemRefType memrefType) -> Type { + return this->convertMemRefType(memrefType); + }); + + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + } + +private: + MemRefType convertMemRefType(MemRefType memrefType) { + // if (!memrefType.hasRank() || memrefType.getRank() == 0) { + // return MemRefType::get(memrefType.getShape(), + // memrefType.getElementType(), + // memrefType.getLayout(), + // IntegerAttr::get(IntegerType::get(memrefType.getContext(), + // 64), + // static_cast(::mlir::ev::MemScope::DDR))); + // } + + + if (memrefType.getMemorySpace() && memrefType.getMemorySpaceAsInt() > 0) { + return memrefType; + } + + return createMemRefTypeWithScope(memrefType, ::mlir::ev::MemScope::DDR); + } +}; + +struct AllocOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto memrefType = cast(op.getType()); + auto newType = + createMemRefTypeWithScope(memrefType, ::mlir::ev::MemScope::MM); + + auto newAlloc = rewriter.create(op.getLoc(), newType); + rewriter.replaceOp(op, newAlloc); + return success(); + } +}; + +LogicalResult +inferResultsToMemRefTypes(Operation *op, ArrayRef operands, + SmallVector &inferredReturnTypes) { + auto inferTypeInterface = dyn_cast(op); + if (inferTypeInterface) { + NamedAttrList attrList(op->getAttrs()); + DictionaryAttr attributes = attrList.getDictionary(op->getContext()); + if (failed(inferTypeInterface.inferReturnTypes( + op->getContext(), op->getLoc(), operands, attributes, + op->getPropertiesStorage(), op->getRegions(), + inferredReturnTypes))) { + return failure(); + } + return success(); + } + + + int64_t firstOperandScope = ::mlir::ev::MemScope::DDR; + for (const Value &operand : operands) { + if (auto memRefType = dyn_cast(operand.getType())) { + if (memRefType.getMemorySpace()) { + firstOperandScope = memRefType.getMemorySpaceAsInt(); + } + break; + } + } + + for (Type resultType : op->getResultTypes()) { + if (auto memrefType = dyn_cast(resultType)) { + Type convertedType = + createMemRefTypeWithScope(memrefType, firstOperandScope); + inferredReturnTypes.push_back(convertedType); + } else { + inferredReturnTypes.push_back(resultType); + } + } + + return success(); +} + +struct GlobalOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto memrefType = cast(op.getType()); + auto newType = + createMemRefTypeWithScope(memrefType, ::mlir::ev::MemScope::DDR); + auto newTypeAttr = TypeAttr::get(newType); + auto newGlobal = rewriter.create( + op.getLoc(), + op.getSymNameAttr(), + op.getSymVisibilityAttr(), + newTypeAttr, + op.getInitialValue().value(), + op.getConstantAttr(), + op.getAlignmentAttr() + ); + + rewriter.replaceOp(op, newGlobal); + return success(); + } + }; + +struct InferTypeOpConverter : public ConversionPattern { + InferTypeOpConverter(MLIRContext *context) + : ConversionPattern(MatchAnyOpTypeTag(), 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector inferredReturnTypes; + if (failed(inferResultsToMemRefTypes(op, operands, inferredReturnTypes))) { + return failure(); + } + + OperationState newOpState(op->getLoc(), op->getName()); + newOpState.addOperands(operands); + newOpState.addTypes(inferredReturnTypes); + newOpState.addAttributes(op->getAttrs()); + newOpState.addSuccessors(op->getSuccessors()); + + for (auto ®ion : op->getRegions()) { + newOpState.addRegion(); + } + + Operation *newOp = rewriter.create(newOpState); + + for (auto [oldRegion, newRegion] : + llvm::zip(op->getRegions(), newOp->getRegions())) { + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.end()); + } + + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// bool applyInferMemrefScopeConversion(ModuleOp moduleOp) { +// ConversionTarget target(*moduleOp.getContext()); +// target.markUnknownOpDynamicallyLegal([](Operation *op) { +// auto types = llvm::concat(op->getOperandTypes(), op->getResultTypes()); +// for (auto type : types) { +// if (auto memrefType = dyn_cast(type)) { +// if (!memrefType.getMemorySpace() || memrefType.getMemorySpaceAsInt() == 0) { +// return false; +// } +// } +// } +// return true; +// }); +// target.addLegalOp(); + +// RewritePatternSet patterns(moduleOp.getContext()); +// patterns.add(moduleOp.getContext()); +// patterns.add(moduleOp.getContext()); +// if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { +// return false; +// } +// return true; +// } + +struct SetMemRefScopePass + : public ::impl::SetMemRefScopeBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + MemRefScopeTypeConverter typeConverter; + + ConversionTarget target(getContext()); + + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + if (auto globalOp = dyn_cast(op)) { + return typeConverter.isLegal(globalOp.getType()); + } + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + target.addLegalOp(); + RewritePatternSet patterns(&getContext()); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + } + + } +}; + +} // namespace + +std::unique_ptr createSetMemRefScopePass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/lib/Transform/Linalg/SplitComputationalOp.cpp b/third_party/evas/lib/Transform/Linalg/SplitComputationalOp.cpp new file mode 100644 index 0000000000..2421819a90 --- /dev/null +++ b/third_party/evas/lib/Transform/Linalg/SplitComputationalOp.cpp @@ -0,0 +1,664 @@ +#include + +#include "epu/memory.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton-shared/Transform/common_utils.h" +#define GEN_PASS_DEF_SPLITCOMPUTATIONALOP +#include "evas/Transform/Linalg/Passes.h.inc" + +namespace mlir::triton::ev { +namespace { +using mlir::func::FuncOp; +using mlir::linalg::LinalgOp; + +static constexpr llvm::StringRef kSchedulePrimitive = "schedule_primitive"; + +using Cluster = llvm::SmallVector; +raw_ostream &operator<<(raw_ostream &os, const Cluster &cluster) { + os << "[\n"; + for (size_t i = 0; i < cluster.size(); ++i) { + if (i != 0) { + os << ", "; // Separate elements with a comma + } + if (cluster[i] != nullptr) { + // Assuming you want to output the address of the Operation object + os << "Operation " << i << ": " << *(cluster[i]) + << "\n"; // Output the address (or use any relevant member function) + // Alternatively, you can call a member function to display additional + // information if desired cluster[i]->print(); + } else { + os << "nullptr\n"; // Handle nullptr entries + } + } + os << "]\n"; + return os; +} + +void reorderCluster(Cluster &cluster) { + if (cluster.empty()) + return; + std::sort(cluster.begin(), cluster.end(), + [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); }); +} + +bool isComputationalOp(Operation *op) { + StringRef opName = op->getName().getStringRef(); + return opName.starts_with("linalg") && opName != "linalg.yield" && opName != "linalg.generic"; +} + +bool isOnlyUser(Operation *A, Operation *B) { + // Check if A has any results + if (A->getNumResults() == 0) + return false; + // Check all results of A + for (Value result : A->getResults()) { + // If any result has no uses, return false + if (result.use_empty()) + return false; + // Check if all uses of this result are in operation B + for (OpOperand &use : result.getUses()) { + if (use.getOwner() != B) + return false; + } + } + // All results of A are only used by B + return true; +} + +bool usedOnlyInCluster(Operation *op, const Cluster &cluster) { + for (Value result : op->getResults()) { + for (OpOperand &use : result.getUses()) { + if (!llvm::is_contained(cluster, use.getOwner())) + return false; + } + } + return true; +} + +SmallVector getInputsOfCluster(const Cluster &cluster) { + llvm::SmallVector inputs; + llvm::SmallDenseSet inputSet; + llvm::SmallDenseSet opSet; + bool hasScatter = false; + + for (Operation *op : cluster) { + if (isa(op)) { + hasScatter = true; + } + bool inserted = opSet.insert(op).second; + (void)inserted; + assert(inserted && "cluster contains duplicate operations"); + } + + for (Operation *op : cluster) { + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (opSet.find(defOp) != opSet.end()) { + // skip if defining op is in the cluster + continue; + } + if (inputSet.insert(operand).second) { + inputs.push_back(operand); + } + } + if (hasScatter) { + // Since the input of scatter is also the output, the input will be added + // again as the output parameter. + inputs.push_back(op->getOperand(0)); + } + } + return inputs; +} + +SmallVector getOutputsOfCluster(const Cluster &cluster) { + llvm::SmallVector outputs; + llvm::SmallDenseSet opSet; + for (Operation *op : cluster) { + // Should add all the operations recursively because a value might be used + // by an operation of an inner region. + op->walk([&](Operation *innerOp) { + bool inserted = opSet.insert(innerOp).second; + (void)inserted; + assert(inserted && "cluster contains duplicate operations"); + }); + } + + for (Operation *op : cluster) { + for (Value result : op->getResults()) { + bool hasExternalUser = + llvm::any_of(result.getUses(), [&](OpOperand &use) { + return !opSet.count(use.getOwner()); + }); + if (hasExternalUser) { + outputs.push_back(result); + } + } + } + return outputs; +} + +Operation *getFirstOpInCluster(const Cluster &cluster) { + Operation *firstOp = *std::min_element( + cluster.begin(), cluster.end(), + [](Operation *x, Operation *y) { return x->isBeforeInBlock(y); }); + return firstOp; +} + +Operation *getLastOpInCluster(const Cluster &cluster) { + Operation *lastOp = *std::max_element( + cluster.begin(), cluster.end(), + [](Operation *x, Operation *y) { return x->isBeforeInBlock(y); }); + return lastOp; +} + +void moveConsumer(const Cluster &cluster) { + Operation *firstOp = getFirstOpInCluster(cluster); + Operation *lastOp = getLastOpInCluster(cluster); + + llvm::SmallDenseSet fusedSet(cluster.begin(), cluster.end()); + llvm::SmallDenseSet consumerSet; + + llvm::SmallVector consumersVec; + auto firstIter = firstOp->getIterator(); + auto lastIter = lastOp->getIterator(); + + for (Operation &curOp : llvm::make_range(firstIter, lastIter)) { + // isn't fused op && consumer's op + // move this after fusion op + if (!fusedSet.contains(&curOp)) { + // fused op's consumer or consumer's consumer + bool isConsumer = + llvm::any_of(curOp.getOperands(), [&fusedSet, &consumerSet](Value v) { + auto op = v.getDefiningOp(); + return fusedSet.contains(op) || consumerSet.contains(op); + }); + if (isConsumer) { + consumerSet.insert(&curOp); + consumersVec.push_back(&curOp); + } + } + } + + for (auto op : llvm::reverse(consumersVec)) { + op->moveAfter(lastOp); + } +} + +bool isCallOpRet(Value v) { + auto op = v.getDefiningOp(); + if (!op) + return false; + if (isa(op)) { + return true; + } + for (auto operand : op->getOperands()) { + if (isCallOpRet(operand)) + return true; + } + return false; +} + +bool isScalarType(Type type) { + return isa(type) || isa(type) || + isa(type) || isa(type); +} + +Operation *findDstInput(Operation *op) { + if (!op) + return nullptr; + if (isa_and_nonnull(op)) { + return op; + } + if (isa_and_nonnull(op)) { + return op; + } + if (op->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + return findDstInput(op->getOperand(0).getDefiningOp()); + } else if (auto fillOp = llvm::dyn_cast_or_null(op)) { + return findDstInput(fillOp.getOperand(1).getDefiningOp()); + } + return nullptr; +} + +Cluster findDestinationOps(Operation *op, + const llvm::SmallDenseSet &dstInputSet, + const Cluster &cls) { + Cluster destinationOps; + // Process each operand of the operation + for (Value input : op->getOperands()) { + Operation *inputOp = input.getDefiningOp(); + // Skip if this is an output dstInput + if (!inputOp || dstInputSet.contains(inputOp)) + continue; + // Skip if the input op is not used only in the cluster + if (!usedOnlyInCluster(inputOp, cls)) + continue; + // If this is an allocation operation, add it directly + if (isa(inputOp)) { + destinationOps.push_back(inputOp); + continue; + } + // Recursively process tensor dialect operations + if (inputOp->getDialect()->getNamespace() == + tensor::TensorDialect::getDialectNamespace()) { + // Get destination ops from the tensor op + Cluster attachedOps = findDestinationOps(inputOp, dstInputSet, cls); + // If we found destination ops, add them and the tensor op + if (!attachedOps.empty()) { + destinationOps.append(attachedOps); + destinationOps.push_back(inputOp); + } + } + } + + return destinationOps; +} + +std::string getFuncName(int clusterIdx) { + std::ostringstream nameStream; + nameStream << "sub_kernel_" << clusterIdx; + return nameStream.str(); +} + +class SplitComputationalOpPass + : public ::impl::SplitComputationalOpBase { +public: + int getClusterIndex(Operation *op) { + for (auto indexedCluster : llvm::enumerate(clusters)) { + auto cluster = indexedCluster.value(); + for (auto clusterOp : cluster) { + if (clusterOp == op) + return indexedCluster.index(); + } + } + return -1; + } + + int mergeClusters(int c1, int c2) { + if (c1 == c2) + return c1; + if (c1 > c2) { + return mergeClusters(c2, c1); + } + Cluster cluster2 = clusters[c2]; + clusters.erase(clusters.begin() + c2); + clusters[c1].append(cluster2.begin(), cluster2.end()); + return c1; + } + + void InitClusters(FuncOp funcOp) { + funcOp.walk([this](Operation *op) { + if (isComputationalOp(op)) { + Cluster newCls; + newCls.push_back(op); + clusters.push_back(newCls); + } + return WalkResult::advance(); + }); + } + + bool ConnnectedTo(const Cluster &clsA, const Cluster &clsB) { + auto valuesA = getOutputsOfCluster(clsA); + auto valuesB = getInputsOfCluster(clsB); + for (auto in : valuesA) { + for (auto out : valuesB) { + if (in == out) + return true; + } + } + return false; + } + + bool HasSingleOutput(const Cluster &clsA) { + auto outputs = getOutputsOfCluster(clsA); + return outputs.size() <= 1; + } + + Cluster getMergedCls(const Cluster &clsA, const Cluster &clsB) { + Cluster ret = clsA; + ret.append(clsB.begin(), clsB.end()); + return ret; + } + + template bool hasLinalgOp(const Cluster &cls) { + for (auto op : cls) { + if (isa(op)) + return true; + } + return false; + } + + template bool IsolatedPattern(const Cluster &cls) { + if (!hasLinalgOp(cls)) + return true; + for (auto op : cls) { + if (isa(op) && !(isa(op))) + return false; + } + return true; + } + + template bool TryIsolatedPattern(const Cluster &cls) { + return (... && IsolatedPattern(cls)); + } + + bool TryRestrictedFusePattern(const Cluster &cls) { + // Multi-output is not supported by evofc for now + if (!HasSingleOutput(cls)) + return false; + + for (Operation *op : cls) { + bool isReduceOp = true; + SmallVector dims{}, shapes{}; + auto inputDimFunc = + [&](auto op) { + dims.push_back(op.getDimension()); + auto shape = op.getInput().getType().getShape(); + shapes.assign(shape.begin(), shape.end()); + }; + auto inputDimsFunc = [&](auto op) { + dims.assign(op.getDimensions().begin(), op.getDimensions().end()); + auto shape = op.getInput().getType().getShape(); + shapes.assign(shape.begin(), shape.end()); + }; + auto inputsDimsFunc = [&](auto op) { + dims.assign(op.getDimensions().begin(), op.getDimensions().end()); + auto type = op.getInputs()[0].getType(); + if (auto shapedType = dyn_cast(type)) { + auto shape = shapedType.getShape(); + shapes.assign(shape.begin(), shape.end()); + } else if (auto tensorType = dyn_cast(type)) { + auto shape = tensorType.getShape(); + shapes.assign(shape.begin(), shape.end()); + } else if (auto memrefType = dyn_cast(type)) { + auto shape = memrefType.getShape(); + shapes.assign(shape.begin(), shape.end()); + } + }; + llvm::TypeSwitch(op) + .Case(inputDimFunc) + .Case(inputDimFunc) + .Case(inputsDimsFunc) + .Case(inputDimFunc) + .Case(inputDimFunc) + .Case(inputsDimsFunc) + .Case(inputsDimsFunc) + .Case(inputDimFunc) + .Case(inputDimFunc) + .Case(inputDimsFunc) + .Case(inputDimFunc) + .Case(inputDimsFunc) + .Case(inputDimsFunc) + .Default([&](auto op) { isReduceOp = false; }); + if (!isReduceOp) + continue; + + bool parallelizable = false; + for (int i = 0; i < shapes.size(); i++) { + auto it = std::find(dims.begin(), dims.end(), i); + if (it == dims.end() && shapes[i] != 1) { + parallelizable = true; + break; + } + } + auto *ctx = op->getContext(); + op->setAttr("isParallelizable", mlir::BoolAttr::get(ctx, parallelizable)); + if (!parallelizable) { + return false; + } + } + + if (!TryIsolatedPattern(cls)) + return false; + + // transpose + return true; + } + + bool CanFuseTo(const Cluster &clsA, const Cluster &clsB) { + if (!ConnnectedTo(clsA, clsB)) + return false; + + Operation *lastOpA = getLastOpInCluster(clsA); + Operation *firstOpB = getFirstOpInCluster(clsB); + Cluster clsMid; + // Check if lastOpA and firstOpB are in the same block + if (lastOpA->getBlock() != firstOpB->getBlock()) + return false; + for (auto &op : llvm::make_range(std::next(lastOpA->getIterator()), + firstOpB->getIterator())) { + clsMid.push_back(&op); + } + // annotate op to cut the fusion + if (utils::getAnnotation(lastOpA)) return false; + // check if the middle cluster is connected to both clsA and clsB + if (!clsMid.empty() && ConnnectedTo(clsA, clsMid) && + ConnnectedTo(clsMid, clsB)) { + return false; + } + auto tryMerged = getMergedCls(clsA, clsB); + return TryRestrictedFusePattern(tryMerged); + } + + void FuseClustersWithDefUse() { + if (clusters.size() <= 1) + return; + for (size_t index = 0; index < clusters.size() - 1; ++index) { + if (CanFuseTo(clusters[index], clusters[index + 1])) { + (void)mergeClusters(index, index + 1); + FuseClustersWithDefUse(); + return; + } + } + } + + Cluster getAttachedCluster(const Cluster &cls) { + Cluster ret = cls; + auto outputs = getOutputsOfCluster(cls); + auto outputsSet = + llvm::SmallDenseSet(outputs.begin(), outputs.end()); + llvm::SmallDenseSet dstInputSet; + // find all the dst input ops that correspond to the cluster outputs + for (auto op : cls) { + auto dstOp = cast(op); + for (auto [idx, output] : llvm::enumerate(op->getResults())) { + if (outputsSet.contains(output)) { + auto init = + findDstInput(dstOp.getDpsInitOperand(idx)->get().getDefiningOp()); + dstInputSet.insert(init); + outputToDstInput[output] = init; + } + } + } + for (auto op : cls) { + auto dst_ops = findDestinationOps(op, dstInputSet, cls); + ret.append(dst_ops); + } + reorderCluster(ret); + return ret; + } + + + void setMemscopeForAllocTensorOp(bufferization::AllocTensorOp allocOp, OpBuilder &b) { + if (auto annotateOp = utils::getAnnotation(allocOp.getOperation())){ + auto meminfo = annotateOp.getMeminfo(); + allocOp.setMemorySpaceAttr(b.getI64IntegerAttr((int64_t)meminfo.getScope())); + } else { + allocOp.setMemorySpaceAttr(b.getI64IntegerAttr((int64_t)mlir::triton::MemScope::L2)); + } + if (mlir::isa(allocOp.getType())) { + auto tensorType = mlir::cast(allocOp.getType()); + if (tensorType.getRank() == 0 || + (tensorType.getRank() == 1 && tensorType.getShape()[0] == 1) || + (tensorType.getRank() == 2 && tensorType.getShape()[0] == 1 && tensorType.getShape()[1] == 1)) { + allocOp.setMemorySpaceAttr(b.getI64IntegerAttr((int64_t)mlir::triton::MemScope::DDR)); + } + } + } + + bool isAnnotatedPrefetch(Operation *op) { + if (auto annotateOp = utils::getAnnotation(op)){ + auto meminfo = annotateOp.getMeminfo(); + return meminfo.getPrefetch(); + } + return false; + } + + void annotatePrefetchToSubKernel(bufferization::AllocTensorOp allocOp, func::CallOp callOp, OpBuilder &b) { + if (isAnnotatedPrefetch(allocOp)) { + callOp->setAttr(mlir::ev::prefetchName, b.getBoolAttr(true)); + } + } + + func::FuncOp createFuncOpWithCluster(OpBuilder &b, StringRef subFnName, + ValueRange inputs, ValueRange outputs, + const Cluster &cluster, + Operation *insertionPoint) { + Operation *lastOp = getLastOpInCluster(cluster); + llvm::SmallVector locations; + locations.reserve(cluster.size()); + for (Operation *op : cluster) { + locations.push_back(op->getLoc()); + } + Location fusedLoc = FusedLoc::get(lastOp->getContext(), locations); + + llvm::SmallVector outputTypes; + outputTypes.reserve(outputs.size()); + for (Value v : outputs) { + outputTypes.push_back(v.getType()); + } + llvm::SmallVector inputTypes; + inputTypes.reserve(inputs.size()); + for (Value v : inputs) { + inputTypes.push_back(v.getType()); + } + + moveConsumer(cluster); + + auto subFnType = b.getFunctionType(inputTypes, outputTypes); + b.setInsertionPoint(insertionPoint); + func::FuncOp subFnOp = + b.create(fusedLoc, subFnName, subFnType); + subFnOp.setSymVisibility("private"); + b.setInsertionPoint(lastOp); + auto callOp = b.create(fusedLoc, subFnOp, inputs); + callOp->setAttr(kSchedulePrimitive, b.getBoolAttr(true)); + // callOp->setAttr(mlir::ev::addrName, + // b.getArrayAttr(SmallVector( + // callOp.getNumResults(), b.getI64IntegerAttr(-1)))); + Block *block = subFnOp.addEntryBlock(); + b.setInsertionPoint(block, block->end()); + IRMapping bvm; + for (auto inputAndArg : llvm::zip(inputs, subFnOp.getArguments())) { + bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg)); + } + for (Operation *op : cluster) { + b.clone(*op, bvm); + } + llvm::SmallVector funcReturns; + for (Value output : outputs) { + funcReturns.push_back(bvm.lookupOrDefault(output)); + } + b.create(fusedLoc, funcReturns); + + for (auto outputAndResult : llvm::zip(outputs, callOp.getResults())) { + Value output = std::get<0>(outputAndResult); + // replace the use of output with the destination alloc op + Operation *dstInputOp = outputToDstInput[output]; + Value dstInputValue = dstInputOp->getResult(0); + Value callResult = std::get<1>(outputAndResult); + for (OpOperand &use : llvm::make_early_inc_range(output.getUses())) { + use.set(dstInputValue); + } + // 只对 bufferization::AllocTensorOp 设置内存空间和预取属性 + if (auto tensorAllocOp = dyn_cast(dstInputOp)) { + setMemscopeForAllocTensorOp(tensorAllocOp, b); + // todo:需要考虑totensor的情况 + annotatePrefetchToSubKernel(tensorAllocOp, callOp, b); + } + } + + // erase dead ops in the end + for (Operation *op : llvm::reverse(cluster)) { + if (op->use_empty()) { + op->erase(); + } + } + + return subFnOp; + } + FailureOr createFuncOpWithCluster(OpBuilder &b, + StringRef subFnName, + const Cluster &cluster, + Operation *insertionPoint) { + auto attachedCluster = getAttachedCluster(cluster); + llvm::SmallVector inputs = getInputsOfCluster(attachedCluster); + llvm::SmallVector outputs = getOutputsOfCluster(attachedCluster); + return createFuncOpWithCluster(b, subFnName, inputs, outputs, + attachedCluster, insertionPoint); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + // set funcName fixed for finding the outer func in further process + auto f = *(m.getOps().begin()); + const std::string funcName = "kernel"; + f.setName(funcName); + InitClusters(f); + FuseClustersWithDefUse(); + OpBuilder b(f); + + SymbolTable symTable(m); + for (auto c : llvm::enumerate(clusters)) { + llvm::outs() << c.value() << "\n"; + FailureOr subFnOp = createFuncOpWithCluster( + b, getFuncName(c.index()), c.value(), f.getOperation()); + assert(mlir::succeeded(subFnOp) && "create FuncOp failed"); + symTable.insert(*subFnOp); + } + } + +private: + SmallVector clusters; + llvm::SmallDenseMap outputToDstInput; +}; + +} // namespace + +std::unique_ptr createSplitComputationalOpPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::ev diff --git a/third_party/evas/patches/triton-shared-llvm22-compat.patch b/third_party/evas/patches/triton-shared-llvm22-compat.patch new file mode 100644 index 0000000000..700ef6c759 --- /dev/null +++ b/third_party/evas/patches/triton-shared-llvm22-compat.patch @@ -0,0 +1,75 @@ +diff --git a/include/triton-shared/Analysis/UseAnalysis.h b/include/triton-shared/Analysis/UseAnalysis.h +--- a/include/triton-shared/Analysis/UseAnalysis.h ++++ b/include/triton-shared/Analysis/UseAnalysis.h +@@ -49,6 +49,7 @@ struct UseInfo : public dataflow::AbstractSparseLattice { + case UseType::MixUse: + return ChangeResult::NoChange; + } ++ return ChangeResult::NoChange; + } + + ChangeResult meet(const AbstractSparseLattice &other) override { +@@ -87,12 +87,6 @@ public: + + void visitCallOperand(OpOperand &operand) override { return; } + +- void +- visitNonControlFlowArguments(RegionSuccessor &successor, +- ArrayRef arguments) override { +- return; +- } +- + void setToExitState(UseInfo *lattice) override { + lattice->type = UseType::Undefined; + } + +diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp +--- a/lib/Analysis/MaskAnalysis.cpp ++++ b/lib/Analysis/MaskAnalysis.cpp +@@ -65,7 +65,8 @@ tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + +- auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, dims); ++ auto dstType = ++ tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); + + return tensor::ExtractSliceOp::create(builder, loc, dstType, source, offsets, + dims, strides); +diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +--- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp ++++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +@@ -974,7 +974,8 @@ private: + SmallVector offsets(rank, b.getIndexAttr(0)); + SmallVector strides(rank, b.getIndexAttr(1)); + +- auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, dims); ++ auto dstType = ++ tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); + + return tensor::ExtractSliceOp::create(b, loc, dstType, source, offsets, + dims, strides); +diff --git a/triton_shared.cc b/triton_shared.cc +--- a/triton_shared.cc ++++ b/triton_shared.cc +@@ -25,10 +25,9 @@ + #include "mlir/Conversion/Passes.h" + #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" + #include "mlir/Dialect/Affine/IR/AffineOps.h" +-#include "mlir/Dialect/Affine/Transforms/Passes.h" ++#include "mlir/Dialect/Affine/Passes.h" + #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" + #include "mlir/Dialect/Arith/Transforms/Passes.h" +-#include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h" + #include "mlir/Dialect/Bufferization/IR/Bufferization.h" + #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" + #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +@@ -70,8 +69,6 @@ void init_triton_triton_shared(py::module &&m) { + cf::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + +- mlir::bufferization::registerAllExtensions(registry); +- + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + });