diff --git a/.github/workflows/llvm-build.yml b/.github/workflows/llvm-build.yml
index b3b7fffae40e..4e944acbeb4a 100644
--- a/.github/workflows/llvm-build.yml
+++ b/.github/workflows/llvm-build.yml
@@ -159,8 +159,8 @@ jobs:
         cp -r /usr/aarch64-linux-gnu/lib ./arm-sysroot
         cp -r /usr/aarch64-linux-gnu/include ./arm-sysroot
         LINKER=$(pwd)/arm-sysroot/lib/ld-linux-aarch64.so.1
-        wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_13.2.0-7_amd64.deb
-        dpkg-deb -x gcc-aarch64-linux-gnu_13.2.0-7_amd64.deb ./arm-sysroot
+        wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb
+        dpkg-deb -x gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb ./arm-sysroot
         export LD_LIBRARY_PATH=$(pwd)/arm-sysroot/lib:$LD_LIBRARY_PATH
         sudo ln -s $LINKER /lib/ld-linux-aarch64.so.1
         SYSROOT="$(pwd)/arm-sysroot"
diff --git a/BUILD b/BUILD
new file mode 100644
index 000000000000..f74ef119cd09
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,900 @@
+# This package imports OpenAI's Triton (https://github.com/openai/triton).
+#
+# There are two versions of Triton in google3 at the moment. The older version
+# can be found at //third_party/py/triton. This is the MLIR-based version close
+# to head. We expect to transition users to this version in the following
+# weeks.
+#
+# There is no SLA associated with this package and it may get broken by LLVM
+# imports at any time.
+
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license")
+
+package(
+    # copybara:uncomment_begin
+    # default_applicable_licenses = [":license"],
+    # default_compatible_with = ["//buildenv/target:gce"],
+    # default_visibility = [
+        # # Add your project here if you need to depend on Triton's C++ sources.
+        # # Add a point of contact we can reach out to when needed in the comment.
+        # #
+        # # If you need to use the Python fronted, add your project to
+        # # google3/third_party/py/triton/BUILD instead.
+        # #
+        # # By adding your project here, you agree to the Triton SLA:  go/triton-google3-sla
+        # "//third_party/py/jax:__subpackages__",  # cjfj@
+        # "//third_party/tensorflow/compiler/xla:__subpackages__",  # bchetioui@
+        # "//platforms/xla/experimental/gpu:__subpackages__",  # csigg@
+        # # Triton-internal visibility
+        # "//:__subpackages__",
+    # ],
+    # copybara:uncomment_end_and_comment_begin
+    default_visibility = ["//visibility:public"],
+    # copybara:comment_end
+    # TODO(csigg): fix and remove
+    features = [
+        "-parse_headers",
+        "-use_header_modules",
+    ],
+)
+
+# copybara:uncomment_begin
+# license(name = "license")
+# 
+# licenses(["notice"])
+# 
+# exports_files(["LICENSE"])
+# copybara:uncomment_end
+
+config_setting(
+    name = "compiler_is_msvc",
+    flag_values = {
+        # copybara:comment_begin
+        "@bazel_tools" +
+        # copybara:comment_end
+        "//tools/cpp:compiler": "msvc-cl",
+    },
+)
+
+# TODO(csigg): fix, enable error upstream, remove.
+_no_unused_variable = select({
+    ":compiler_is_msvc": [],
+    "//conditions:default": ["-Wno-unused-variable"],
+})
+
+td_library(
+    name = "td_files",
+    srcs = glob(["include/triton/**/*.td"]),
+    includes = ["include"],
+    deps = [
+        "@llvm-project//mlir:ArithOpsTdFiles",
+        "@llvm-project//mlir:CastInterfacesTdFiles",
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+        "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
+        "@llvm-project//mlir:FunctionInterfacesTdFiles",
+        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+        "@llvm-project//mlir:LLVMOpsTdFiles",
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:PassBaseTdFiles",
+        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "triton_attr_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-attrdef-decls"],
+            "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc",
+        ),
+        (
+            ["--gen-attrdef-defs"],
+            "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_dialect_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-dialect-decls"],
+            "include/triton/Dialect/Triton/IR/Dialect.h.inc",
+        ),
+        (
+            ["--gen-dialect-defs"],
+            "include/triton/Dialect/Triton/IR/Dialect.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_interfaces_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-attr-interface-decls"],
+            "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc",
+        ),
+        (
+            ["--gen-attr-interface-defs"],
+            "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_ops_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-enum-decls"],
+            "include/triton/Dialect/Triton/IR/OpsEnums.h.inc",
+        ),
+        (
+            ["--gen-enum-defs"],
+            "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc",
+        ),
+        (
+            ["--gen-op-decls"],
+            "include/triton/Dialect/Triton/IR/Ops.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "include/triton/Dialect/Triton/IR/Ops.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonOps.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_types_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-typedef-decls"],
+            "include/triton/Dialect/Triton/IR/Types.h.inc",
+        ),
+        (
+            ["--gen-typedef-defs"],
+            "include/triton/Dialect/Triton/IR/Types.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_transforms_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=Triton",
+            ],
+            "include/triton/Dialect/Triton/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/Transforms/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_combine_inc_gen",
+    # The generated file is #included without relative path.
+    strip_include_prefix = "lib/Dialect/Triton/Transforms",
+    tbl_outs = [
+        (
+            ["--gen-rewriters"],
+            "lib/Dialect/Triton/Transforms/TritonCombine.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "lib/Dialect/Triton/Transforms/Combine.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_gpu_attr_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-attrdef-decls"],
+            "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc",
+        ),
+        (
+            ["--gen-attrdef-defs"],
+            "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc",
+        ),
+        (
+            ["--gen-enum-decls"],
+            "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc",
+        ),
+        (
+            ["--gen-enum-defs"],
+            "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc",
+        ),
+        (
+            ["--gen-attr-interface-decls"],
+            "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc",
+        ),
+        (
+            ["--gen-attr-interface-defs"],
+            "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_gpu_dialect_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-dialect-decls"],
+            "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc",
+        ),
+        (
+            ["--gen-dialect-defs"],
+            "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_gpu_ops_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-op-decls"],
+            "include/triton/Dialect/TritonGPU/IR/Ops.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_gpu_types_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-typedef-decls"],
+            "include/triton/Dialect/TritonGPU/IR/Types.h.inc",
+        ),
+        (
+            ["--gen-typedef-defs"],
+            "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_gpu_transforms_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonGPU",
+            ],
+            "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_nvidia_gpu_attr_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-attrdef-decls"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc",
+        ),
+        (
+            ["--gen-attrdef-defs"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc",
+        ),
+        (
+            ["--gen-enum-decls"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc",
+        ),
+        (
+            ["--gen-enum-defs"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_nvidia_gpu_dialect_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-dialect-decls"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc",
+        ),
+        (
+            ["--gen-dialect-defs"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_nvidia_gpu_ops_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-op-decls"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_nvidia_gpu_types_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-typedef-decls"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc",
+        ),
+        (
+            ["--gen-typedef-defs"],
+            "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_nvidia_gpu_transforms_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonNvidiaGPU",
+            ],
+            "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonToTritonGPU",
+            ],
+            "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_target_llvmir_passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonLLVMIR",
+            ],
+            "include/triton/Target/LLVMIR/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Target/LLVMIR/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonGPUToLLVM",
+            ],
+            "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_type_interfaces_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-type-interface-decls"],
+            "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc",
+        ),
+        (
+            ["--gen-type-interface-defs"],
+            "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td",
+    deps = ["td_files"],
+)
+
+cc_library(
+    name = "TritonAnalysis",
+    srcs = [
+        "lib/Analysis/Alias.cpp",
+        "lib/Analysis/Allocation.cpp",
+        "lib/Analysis/Membar.cpp",
+        # Part of TritonDialects compilation unit to avoid circular dependencies.
+        # "lib/Analysis/Utility.cpp",
+        # "lib/Analysis/AxisInfo.cpp",
+    ],
+    hdrs = [
+        "include/triton/Analysis/Alias.h",
+        "include/triton/Analysis/Allocation.h",
+        "include/triton/Analysis/Membar.h",
+        # Part of TritonDialects compilation unit to avoid circular dependencies.
+        # "include/triton/Analysis/AxisInfo.h",
+        # "include/triton/Analysis/Utility.h",
+        "include/triton/Conversion/MLIRTypes.h",
+        "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h",
+        "include/triton/Conversion/TritonGPUToLLVM/Utility.h",
+        "include/triton/Dialect/TritonGPU/Transforms/Utility.h",
+    ],
+    copts = _no_unused_variable,
+    deps = [
+        ":TritonDialects",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
+    name = "TritonDialects",
+    srcs = glob([
+        "lib/Dialect/Triton/IR/*.cpp",
+        "lib/Dialect/TritonGPU/IR/*.cpp",
+        "lib/Dialect/TritonNvidiaGPU/IR/*.cpp",
+        "lib/Tools/*.cpp",
+    ]) + [
+        "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h",  # Avoid circular dependency.
+        "lib/Analysis/AxisInfo.cpp",  # Avoid circular dependency.
+        "lib/Analysis/Utility.cpp",  # Avoid circular dependency.
+        "lib/Dialect/TritonGPU/Transforms/Utility.cpp",  # Avoid circular dependency.
+    ],
+    hdrs = glob([
+        "include/triton/Dialect/Triton/IR/*.h",
+        "include/triton/Dialect/TritonGPU/IR/*.h",
+        "include/triton/Dialect/TritonNvidiaGPU/IR/*.h",
+        "include/triton/Tools/*.h",
+    ]) + [
+        "include/triton/Analysis/AxisInfo.h",  # Avoid circular dependency.
+        "include/triton/Analysis/Utility.h",  # Avoid circular dependency.
+        "include/triton/Dialect/TritonGPU/Transforms/Utility.h",  # Avoid circular dependency.
+    ],
+    copts = select({
+        ":compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+            "-Wno-logical-op-parentheses",
+        ],
+    }),
+    includes = ["include"],
+    deps = [
+        ":triton_dialect_inc_gen",
+        ":triton_gpu_attr_inc_gen",
+        ":triton_gpu_dialect_inc_gen",
+        ":triton_gpu_ops_inc_gen",
+        ":triton_gpu_types_inc_gen",
+        ":triton_interfaces_inc_gen",
+        ":triton_nvidia_gpu_attr_inc_gen",
+        ":triton_nvidia_gpu_dialect_inc_gen",
+        ":triton_nvidia_gpu_ops_inc_gen",
+        ":triton_nvidia_gpu_types_inc_gen",
+        ":triton_ops_inc_gen",
+        ":triton_types_inc_gen",
+        ":triton_type_interfaces_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FunctionInterfaces",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InliningUtils",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@triton//third_party/nvidia:NVGPUDialect",
+        # The following is added to make Utility compile
+        ":TritonTools",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "@triton//third_party/f2reduce",
+    ],
+)
+
+cc_library(
+    name = "TritonTransforms",
+    srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]),
+    hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]),
+    copts = _no_unused_variable,
+    deps = [
+        ":TritonDialects",
+        ":triton_combine_inc_gen",
+        ":triton_transforms_inc_gen",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = True,  # TritonDialect uses getCanonicalizationPatterns().
+)
+
+cc_library(
+    name = "TritonGPUTransforms",
+    srcs = glob(
+        [
+            "lib/Dialect/TritonGPU/Transforms/*.cpp",
+            "lib/Dialect/TritonGPU/Transforms/*.h",
+            "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp",
+            "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h",
+        ],
+        exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"],
+    ),
+    hdrs = glob(
+        [
+            "include/triton/Dialect/TritonGPU/Transforms/*.h",
+        ],
+        exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"],
+    ) + [
+        "include/triton/Tools/Sys/GetEnv.hpp",
+    ],
+    copts = select({
+        ":compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-reorder-ctor",
+            "-Wno-return-type",
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        ":TritonAnalysis",
+        ":TritonDialects",
+        ":triton_gpu_transforms_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SCFTransforms",
+        "@llvm-project//mlir:SCFUtils",
+        "@llvm-project//mlir:SideEffectInterfaces",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
+    name = "TritonGPUToLLVM",
+    srcs = glob([
+        "lib/Conversion/TritonGPUToLLVM/*.h",
+        "lib/Conversion/TritonGPUToLLVM/**/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/triton/Tools/Sys/*.hpp",
+        "include/triton/Conversion/TritonGPUToLLVM/*.h",
+    ]),
+    copts = select({
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    includes = ["include"],
+    deps = [
+        ":TritonAnalysis",
+        ":TritonDialects",
+        ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen",
+        ":triton_gpu_attr_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:DataLayoutInterfaces",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
+    name = "TritonNvidiaGPUTransforms",
+    srcs = glob([
+        "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h",
+    ]),
+    copts = select({
+        ":compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-ctad-maybe-unsupported",
+            "-Wno-logical-op-parentheses",
+            "-Wno-non-virtual-dtor",
+            "-Wno-return-type",
+            "-Wno-unused-variable",
+        ],
+    }),
+    includes = ["include"],
+    deps = [
+        ":TritonDialects",
+        ":TritonGPUTransforms",
+        ":triton_nvidia_gpu_transforms_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
+    name = "TritonToTritonGPU",
+    srcs = glob([
+        "lib/Conversion/TritonToTritonGPU/*.h",
+        "lib/Conversion/TritonToTritonGPU/*.cpp",
+    ]),
+    hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":TritonDialects",
+        ":TritonGPUTransforms",
+        ":triton_conversion_triton_to_triton_gpu_passes_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
+    name = "TritonLLVMIR",
+    srcs = glob([
+        "lib/Target/LLVMIR/*.cpp",
+        "lib/Target/LLVMIR/*.h",
+    ]),
+    hdrs = glob(["include/triton/Target/LLVMIR/*.h"]),
+    copts = _no_unused_variable,
+    deps = [
+        ":TritonTransforms",
+        ":triton_target_llvmir_passes_inc_gen",
+        "@llvm-project//llvm:Analysis",
+        "@llvm-project//llvm:BinaryFormat",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:IPO",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:InstCombine",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:MC",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:Target",
+        "@llvm-project//mlir:ArithToLLVM",
+        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
+        "@llvm-project//mlir:ConversionPasses",
+        "@llvm-project//mlir:ExecutionEngine",
+        "@llvm-project//mlir:ExecutionEngineUtils",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexToLLVM",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LLVMIRTransforms",
+        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
+        "@llvm-project//mlir:SCFToControlFlow",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ToLLVMIRTranslation",
+        "@llvm-project//mlir:Transforms",
+        # copybara:uncomment "//third_party/py/triton/google:find_cuda",
+    ],
+)
+
+cc_library(
+    name = "TritonPTX",
+    srcs = glob([
+        "lib/Target/PTX/*.cpp",
+    ]),
+    hdrs = glob(["include/triton/Target/PTX/*.h"]),
+    deps = ["@llvm-project//llvm:Support"],
+)
+
+cc_library(
+    name = "TritonHSACO",
+    srcs = glob([
+        "lib/Target/HSACO/*.cpp",
+    ]),
+    hdrs = glob(["include/triton/Target/HSACO/*.h"]),
+    deps = [
+        ":TritonLLVMIR",
+        ":TritonTools",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:ExecutionEngine",
+        "@llvm-project//llvm:MC",
+        "@llvm-project//llvm:Scalar",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:Target",
+        "@llvm-project//llvm:TransformUtils",
+        "@llvm-project//mlir:ExecutionEngine",
+        "@llvm-project//mlir:ExecutionEngineUtils",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ToLLVMIRTranslation",
+    ],
+)
+
+cc_library(
+    name = "TritonTools",
+    hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"],
+)
+
+cc_library(
+    name = "AllPassesAndDialects",
+    srcs = [
+        "include/triton/Conversion/TritonToTritonGPU/Passes.h",
+        "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h",
+    ],
+    hdrs = ["bin/RegisterTritonDialects.h"],
+    includes = ["."],  # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h
+    deps = [
+        ":TritonDialects",
+        ":TritonGPUToLLVM",
+        ":TritonGPUTransforms",
+        ":TritonLLVMIR",
+        ":TritonNvidiaGPUTransforms",
+        ":TritonToTritonGPU",
+        ":TritonTransforms",
+        ":triton_conversion_triton_to_triton_gpu_passes_inc_gen",
+        ":triton_nvidia_gpu_transforms_inc_gen",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@triton//test:TritonTestAnalysis",
+        "@triton//third_party/amd:TritonAMDGPUToLLVM",
+        "@triton//third_party/amd:TritonAMDGPUTransforms",
+        "@triton//third_party/nvidia:NVGPUDialect",
+        "@triton//third_party/nvidia:NVGPUToLLVM",
+        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
+    ],
+)
+
+cc_binary(
+    name = "triton-opt",
+    srcs = [
+        "bin/triton-opt.cpp",
+    ],
+    deps = [
+        ":AllPassesAndDialects",
+        "@llvm-project//mlir:MlirOptLib",
+    ],
+)
+
+cc_binary(
+    name = "triton-llvm-opt",
+    srcs = [
+        "bin/triton-llvm-opt.cpp",
+        "lib/Target/LLVMIR/LLVMPasses.h",
+    ],
+    deps = [
+        ":TritonLLVMIR",
+        "@llvm-project//llvm:CodeGen",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Option",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TargetParser",
+    ],
+)
+
+# See go/triton-debug for usage.
+cc_binary(
+    name = "triton-reduce",
+    srcs = ["bin/triton-reduce.cpp"],
+    deps = [
+        ":AllPassesAndDialects",
+        "@llvm-project//mlir:MlirReduceLib",
+    ],
+)
+
+cc_binary(
+    name = "triton-tensor-layout",
+    srcs = ["bin/triton-tensor-layout.cpp"],
+    deps = [
+        ":AllPassesAndDialects",
+        ":TritonDialects",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+filegroup(
+    name = "metadata-file",
+    srcs = ["METADATA"],
+)
diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt
index 1522498c600e..1995c62654d4 100644
--- a/cmake/llvm-hash.txt
+++ b/cmake/llvm-hash.txt
@@ -1 +1 @@
-ce80c80dca45c7b4636a3e143973e2c6cbdb2884
+7f7f4feaf07dd3bb4b22d0c25d34b6c99c753aa2
diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
index 6ceb4bc47665..4c709cd4420b 100644
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
 }
 
 // Floating-point Type
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
 def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
 
diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h
index 0ef59714733d..1ff63697ec0d 100644
--- a/include/triton/Dialect/Triton/IR/Utility.h
+++ b/include/triton/Dialect/Triton/IR/Utility.h
@@ -31,7 +31,11 @@ template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
 
 /// Get the highest power of 2 divisor of an integer.
 template <typename T> T highestPowOf2Divisor(T n) {
-  if (n == 0) {
+  // When n is 0 or min, return the highest power of 2. The min case is handled
+  // separately to avoid underflow when T is a signed integer. Technically
+  // in that case the correct divisor is -n, but this value is outside the
+  // range of possible values, so we take the next best alternative.
+  if (n == 0 || n == std::numeric_limits<T>::min()) {
     return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
   }
   return (n & (~(n - 1)));
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index 52082ddf7021..0773d58a740b 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -49,11 +49,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
   }
 
   if (pessimistic) {
-    return setAllToEntryStates(results);
+    setAllToEntryStates(results);
+    return;
   }
   // Join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(aliasInfo));
+  return;
 }
 
 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 1f3368d0d481..720737243b9b 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1048,8 +1048,10 @@ void AxisInfoAnalysis::visitOperation(
     if (op->getValue().getRank() == 0)
       setToEntryState((dataflow::Lattice<AxisInfo> *)op);
   AxisInfo curr = visitors.apply(op, operands);
-  if (curr.getRank() == 0)
-    return setAllToEntryStates(results);
+  if (curr.getRank() == 0) {
+    setAllToEntryStates(results);
+    return;
+  }
   // override with hint
   auto newContiguity = curr.getContiguity();
   auto newDivisibility = curr.getDivisibility();
@@ -1071,6 +1073,7 @@ void AxisInfoAnalysis::visitOperation(
   // join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(curr));
+  return;
 }
 
 void AxisInfoAnalysis::visitForOpInductionVar(
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 933f062d8191..0d3d9fd35b81 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
   if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
     return false;
 
+  auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
   auto F8E5M2 = TypeID::get<Float8E5M2Type>();
   auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
   auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
@@ -436,6 +437,7 @@ bool supportMFMATypes(Type a, Type b) {
       {F32, F32},
       {F16, F16},
       {BF16, BF16},
+      {F8E4M3FN, F8E4M3FN},
       {F8E5M2, F8E5M2},
       {F8E4M3FNUZ, F8E4M3FNUZ},
       {F8E4M3FNUZ, F8E5M2FNUZ},
@@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
+          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
            aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
            aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 0287207be51a..7aca67d97e90 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   auto ouEltTy = ouTensorTy.getElementType();
   if (inBitWidth == ouBitWidth)
     return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
+  if ((inBitWidth == 16 && ouBitWidth == 32) ||
+      (inBitWidth == 32 && ouBitWidth == 16)) {
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index 53705c3b78b9..c0371cfe1d1b 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
+  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
+    return IntegerType::get(type.getContext(), 8);
+  });
   addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index d0988be3bc95..dab9f3fdf068 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
     // LLVM IR.
     if (type::isFloat8(elemType))
       elemType = rewriter.getIntegerType(8);
-    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
     auto typeConverter = getTypeConverter();
+    auto constOp = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(elemType), val);
     auto llStruct = SplatOpConversion::convertSplatLikeOp(
         elemType, op.getType(), constOp, typeConverter, rewriter, loc);
     rewriter.replaceOp(op, llStruct);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d5b5d459a910..6620e92d3c9b 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2721,6 +2721,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 39c043695bc6..72f73682bd00 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                              dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to Predicates are not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits visibility of the original bit-width so that predicate
+    // are not considered, hence, kwidth can never be = 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
@@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
+      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index 6d8279795209..e6e0ec8d7cef 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                 PatternRewriter &rewriter) const override {
     // Only consider conversions to dot operand.
     auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
       return failure();
 
     auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                 [](Type ty) { return isa<RankedTensorType>(ty); }))
       return failure();
 
+    // Quick handling to fix loading issues when computing the original
+    // bitwidth is unable to realize that there is a mixed-precision dot
+    // (hence kWidth = 1) but wants to hoist through the type conversion.
+    if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+        return failure();
+
     // Only consider custom conversions or arith ops.
     // TODO(jlebar): Is this too restrictive?
     if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
     if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
       return failure();
 
+    // Don't hoist through u1 -> fp casts as they aren't supported in
+    // ElementwiseOpToLLVM::reorderValues().
+    if (isa<arith::UIToFPOp>(src)) {
+      Type srcType = getElementTypeOrSelf(src->getOperand(0));
+      if (srcType.isInteger(1))
+        return failure();
+    }
+
     // Check that the conversion is transitively dependent on a load, and all
     // operations between the load and the conversion are layout preserving.
     //
diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
index 02994e1ac059..cd6fc806928d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                type.getMemorySpace()),
       v, offsetsVal);
 
+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
         break;
       if (!op->getResult(0).hasOneUse())
         break;
+      // Similar to issues faced in HoistLayoutConversion pattern in
+      // OptimizeDotOperands.cpp, we can't propagate through type casts from
+      // predicates as they aren't supported in Triton when encoded with dot_op
+      // layout.
+      if (isa<arith::UIToFPOp>(op)) {
+        Type srcType = getElementTypeOrSelf(op->getOperand(0));
+        if (srcType.isInteger(1))
+          break;
+      }
       rets.push_back(op->getOperand(0));
       if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
         foundConvertFromShared = true;
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index b6d855a05388..984b8e6eb058 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -138,11 +138,6 @@ class LayoutRematerialization {
                     ConvertLayoutOp convertOp);
 
 private:
-  void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
-  // Existing tuples of (value, layout) that needs to be updated when recreating
-  // scf ops. This prevents keeping track of Values that have been delete when
-  // rewriting slices.
-  DenseMap<Value, Attribute> mappedValues;
   // map of the values remat based on encoding.
   DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -154,7 +149,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                             Value newV) {
   LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
   rematMapping[{old, encoding}] = newV;
-  mappedValues[old] = encoding;
 }
 
 // Remove unneeded values now that we are done with the rematMapping.
@@ -813,31 +807,6 @@ bool canBeRemat(Operation *op) {
   return true;
 }
 
-void LayoutRematerialization::updateRematMapping(
-    SmallVector<std::tuple<Value, Value>> &values) {
-  for (auto [old, newV] : values) {
-    auto it = mappedValues.find(old);
-    if (it != mappedValues.end()) {
-      Attribute encoding = it->second;
-      auto rematIt = rematMapping.find({old, it->second});
-      assert(rematIt != rematMapping.end());
-      Value replacedValue = rematIt->second;
-      rematMapping.erase(rematIt);
-      mappedValues.erase(it);
-      // Loop through the replacement value to find the new version of remat
-      // value. This should be okay as the number of values should be small.
-      for (auto [before, after] : values) {
-        if (before == replacedValue) {
-          replacedValue = after;
-          break;
-        }
-      }
-      rematMapping[{newV, encoding}] = replacedValue;
-      mappedValues[newV] = encoding;
-    }
-  }
-}
-
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                            DenseMap<Value, Attribute> &layout,
                                            ConvertLayoutOp convertOp,
@@ -850,14 +819,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   // for/yield to fall out of sync
   SetVector<Value> valuesWithExistingRemat;
   for (Value v : slice) {
-    auto layoutIt = layout.find(v);
-    assert(layoutIt != layout.end());
-    // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
-      valuesWithExistingRemat.insert(v);
-      continue;
-    }
     if (v.getDefiningOp()) {
       opsToRewrite.insert(v.getDefiningOp());
       if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
@@ -947,8 +908,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
         if (slice.count(res)) {
           // Why can't we use res instead of ifOp.getResult(oldIdx)?
           mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx));
-          addRematValue(ifOp.getResult(oldIdx), layout[res],
-                        newIfOp.getResult(newIdx));
+          addRematValue(res, layout[res], newIfOp.getResult(newIdx));
           ++newIdx;
         }
         ++oldIdx;
@@ -979,8 +939,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
                                                  newOp->getResult(0));
       mapping.map(op->getResult(0), cvt.getResult());
-      addRematValue(op->getResult(0), layout[op->getResult(0)],
-                    cvt.getResult());
       continue;
     }
     Operation *newOp = builder.clone(*op, mapping);
@@ -992,14 +950,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
           cast<RankedTensorType>(old.getType()).getShape(),
           cast<RankedTensorType>(old.getType()).getElementType(), it->second);
       newV.setType(newType);
-      addRematValue(old, it->second, newV);
     }
   }
   // Check mapping and see if there are existing convertOps on the old Argument
   convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc()));
   opToDelete.insert(convertOp);
 
-  updateRematMapping(replacements);
   for (auto &kv : replacements) {
     builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
   }
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 57e41e55ff4f..6f4480d6fe93 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,8 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
     SmallVector<unsigned> validN;
 
     // MMAv3 with larger instruction shape is preferred.
-    if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() ||
-        eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
+    if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
+        eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
+        eltType.isF32()) {
       validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                      168, 160, 152, 144, 136, 128, 120, 112, 104, 96,  88,
                      80,  72,  64,  56,  48,  40,  32,  24,  16,  8});
diff --git a/python/BUILD b/python/BUILD
new file mode 100644
index 000000000000..334dd4aec41a
--- /dev/null
+++ b/python/BUILD
@@ -0,0 +1,77 @@
+# NOTE: Do not depend on any targets from this directory,
+# but use //third_party/py/triton instead.
+
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+package(
+    default_applicable_licenses = ["//:license"],
+    default_visibility = [
+        "//third_party/py/triton:__pkg__",
+        "@triton//python:__subpackages__",
+    ],
+)
+
+cc_library(
+    name = "passes",
+    hdrs = ["src/passes.h"],
+    includes = ["src"],
+    visibility = ["@triton//third_party:__subpackages__"],
+)
+
+pybind_extension(
+    name = "libtriton",
+    srcs = [
+        "src/interpreter.cc",
+        "src/ir.cc",
+        "src/llvm.cc",
+        "src/main.cc",
+        "src/passes.cc",
+    ],
+    copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
+    deps = [
+        ":passes",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:IPO",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:InstCombine",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:MC",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:Target",
+        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
+        "@llvm-project//mlir:BytecodeWriter",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ConversionPasses",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LLVMIRTransforms",
+        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ToLLVMIRTranslation",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+        "//:TritonGPUTransforms",
+        "//:TritonHSACO",
+        "//:TritonLLVMIR",
+        "//:TritonNvidiaGPUTransforms",
+        "//:TritonPTX",
+        "//:TritonToTritonGPU",
+        "//:TritonTools",
+        "//:TritonTransforms",
+        "@triton//third_party/nvidia:triton_nvidia",
+    ],
+)
+
+filegroup(
+    name = "files",
+    srcs = glob(
+        include = ["triton/**/*.py"],
+    ),
+)
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 46095dcc6653..7575cf87a5de 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1,4 +1,4 @@
-#include <pybind11/functional.h>
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
@@ -224,6 +225,7 @@ void init_triton_ir(py::module &&m) {
                     math::MathDialect, arith::ArithDialect, index::IndexDialect,
                     scf::SCFDialect, ::mlir::gpu::GPUDialect,
                     cf::ControlFlowDialect, LLVM::LLVMDialect>();
+    mlir::LLVM::registerInlinerInterface(registry);
     registerBuiltinDialectTranslation(registry);
     registerLLVMDialectTranslation(registry);
     context.appendDialectRegistry(registry);
@@ -745,10 +747,8 @@ void init_triton_ir(py::module &&m) {
              return self.getBuilder().getI64Type();
            })
       .def("get_fp8e4nv_ty",
-           // TODO: fp8e4nv is using Float8E4M3FNUZType, which
-           // does not seem right. It should use FloatE4M3FNType
            [](TritonOpBuilder &self) -> Type {
-             return self.getBuilder().getType<Float8E4M3FNUZType>();
+             return self.getBuilder().getType<Float8E4M3FNType>();
            })
       .def("get_fp8e4b8_ty",
            [](TritonOpBuilder &self) -> Type {
diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD
new file mode 100644
index 000000000000..a88f4eeae1f8
--- /dev/null
+++ b/python/test/regression/BUILD
@@ -0,0 +1,26 @@
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")
+
+package(
+    default_applicable_licenses = ["//:license"],
+)
+
+pytest_multi_tests(
+    name = "tests",
+    size = "large",
+    srcs = ["conftest.py"],
+    shard_count = 10,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm80",
+    ],
+    tests = glob(
+        include = ["test_*.py"],
+        exclude = [
+            "test_performance.py",  #TODO(b/321005767): fix failing test
+        ],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py
new file mode 100644
index 000000000000..7a02d322b49f
--- /dev/null
+++ b/python/test/regression/conftest.py
@@ -0,0 +1,12 @@
+# content of conftest.py
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default='cuda')
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD
new file mode 100644
index 000000000000..6997c23e0635
--- /dev/null
+++ b/python/test/unit/BUILD
@@ -0,0 +1,179 @@
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test")
+
+package(
+    default_applicable_licenses = ["//:license"],
+)
+
+_requires_gpu_sm80 = [
+    "config-cuda-only",
+    "requires-gpu-sm80",
+]
+
+_requires_config_cuda = select(
+    {"@local_config_cuda//cuda:using_clang_allow_exec": []},
+    no_match_error = "Requires --config=cuda",
+)
+
+EXCLUDE_TESTS = [
+    "language/test_pipeliner_h100.py",  # TODO(b/362458006): fix failing test
+    "language/test_reproducer.py",  # this is not an actual test, but a tool for running reproducers
+    "language/test_subprocess.py",  # TODO(b/320224484): fix failing test
+    "runtime/test_launch.py",  # TODO(b/320226169): fix failing tests
+    "tools/test_aot.py",  # TODO(b/320224484): fix failing test
+    "tools/test_disasm.py",  # TODO(b/320224484): fix failing test
+    "hopper/test_experimental_tma_h100.py",  # TODO(b/362458006): fix failing test
+    "hopper/test_persistent_warp_specialized_gemm.py",  # TODO (b/342348738): fix failing test
+    "hopper/test_tma_descriptor.py",  # TODO (b/358060133): fix failing test
+    "runtime/test_cublas.py",  # TODO(b/346755023): fix failing test
+]
+
+# Runs all python tests on H100
+pytest_multi_tests(
+    name = "hopper",
+    size = "large",
+    srcs = [
+        "conftest.py",
+        "language/conftest.py",
+        "language/test_core.py",
+    ],
+    name_suffix = "_h100",
+    shard_count = 10,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm90",
+    ],
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["**/test_*.py"],
+        exclude = EXCLUDE_TESTS + ["language/test_core.py"],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+# Shard test_core more, as it is otherwise very slow to run.
+pytest_test(
+    name = "hopper/language/test_core_h100",
+    size = "large",
+    srcs = [
+        "conftest.py",
+        "language/conftest.py",
+    ],
+    shard_count = 40,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm90",
+    ],
+    target_compatible_with = _requires_config_cuda,
+    tests = ["language/test_core.py"],
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+pytest_multi_tests(
+    name = "language",
+    size = "large",
+    srcs = [
+        "conftest.py",
+        "language/conftest.py",
+        "language/test_core.py",
+    ],
+    shard_count = 10,
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["language/**/test_*.py"],
+        exclude = EXCLUDE_TESTS + ["language/test_core.py"],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+# Shard test_core more, as it is otherwise very slow to run.
+pytest_test(
+    name = "language/test_core",
+    size = "large",
+    srcs = [
+        "conftest.py",
+        "language/conftest.py",
+    ],
+    shard_count = 40,
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = ["language/test_core.py"],
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+pytest_multi_tests(
+    name = "instrumentation",
+    size = "large",
+    srcs = ["conftest.py"],
+    shard_count = 10,
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["instrumentation/**/test_*.py"],
+        exclude = EXCLUDE_TESTS,
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+pytest_multi_tests(
+    name = "runtime",
+    srcs = ["conftest.py"],
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["runtime/**/test_*.py"],
+        exclude = EXCLUDE_TESTS,
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+pytest_multi_tests(
+    name = "tools",
+    size = "large",
+    shard_count = 10,
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["tools/**/test_*.py"],
+        exclude = EXCLUDE_TESTS,
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+pytest_multi_tests(
+    name = "unit",
+    size = "large",
+    srcs = ["conftest.py"],
+    shard_count = 10,
+    tags = _requires_gpu_sm80,
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["test_*.py"],
+        exclude = EXCLUDE_TESTS,
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 9e5ff8a2ce37..60ad26871a53 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2144,6 +2144,8 @@ def kernel(X, Z, BLOCK: tl.constexpr):
 reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]]
 
 
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9,
+                    reason='Reduction test produces wrong results on H100, b/342347027')
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config +
@@ -3637,6 +3639,25 @@ def _kernel(out):
     kernel[(1, )](out)
     assert torch.all(out == out_ref)
 
+@pytest.mark.interpreter
+def test_dot_on_broadcast(device):
+    @triton.jit
+    def _kernel(a, b, out):
+        a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64)
+        rhs = tl.load(b)
+        rhs_bc = tl.broadcast_to(rhs, [32, 32])
+        c = tl.dot(lhs, rhs_bc)
+        out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        tl.store(out_ptr, c)
+
+    a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device)
+    b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device)
+    out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32)))
+    out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device)
+    _kernel[(1, )](a, b, out, num_stages=1, num_warps=4)
+    assert torch.all(out == out_ref)
+
 
 # ---------------
 # test arange
diff --git a/python/triton/_C/include b/python/triton/_C/include
index b85a409837d1..8a5dba6c4b56 120000
--- a/python/triton/_C/include
+++ b/python/triton/_C/include
@@ -1 +1 @@
-../../../include/
\ No newline at end of file
+../../../include
\ No newline at end of file
diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py
index 92ba144ba97b..f9bab523bf6c 100644
--- a/python/triton/backends/__init__.py
+++ b/python/triton/backends/__init__.py
@@ -46,5 +46,8 @@ def _discover_backends():
                                  _find_concrete_subclasses(driver, DriverBase))
     return backends
 
-
-backends = _discover_backends()
+from triton.backends.nvidia.driver import CudaDriver
+from triton.backends.nvidia.compiler import CUDABackend
+backends = {
+        "nvidia": Backend(CUDABackend, CudaDriver)
+}
diff --git a/test/BUILD b/test/BUILD
new file mode 100644
index 000000000000..ca98d32da7b0
--- /dev/null
+++ b/test/BUILD
@@ -0,0 +1,63 @@
+# copybara:uncomment_begin
+# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests")
+# load("//tools/build_defs/build_test:build_test.bzl", "build_test")
+# 
+# package(
+#     default_applicable_licenses = ["//:license"],
+#     default_compatible_with = ["//buildenv/target:gce"],
+#     default_visibility = ["//:__subpackages__"],
+# )
+# 
+# glob_lit_tests(
+#     name = "all_tests",
+#     data = [
+#         "@llvm-project//llvm:FileCheck",
+#         "//:triton-llvm-opt",
+#         "//:triton-opt",
+#         "//:triton-tensor-layout",
+#     ],
+#     driver = "@llvm-project//mlir:run_lit.sh",
+#     exclude = [
+#         "Conversion/amd/dedup-by-constancy.mlir",  # AMD-specific, broken
+#         "TritonGPU/combine.mlir",  # TODO: b/338346821 - needs cse or something.
+#         "TritonGPU/dot-operands.mlir",  # TODO: b/283035396 - broken by cl536931041.patch
+#         "TritonGPU/optimize_epilogue.mlir",  # TODO: b/346283526 - AMD-specific, triggering UBSAN
+#     ],
+#     test_file_exts = [
+#         "mlir",
+#         "ll",
+#     ],
+# )
+# 
+# build_test(
+#     name = "build_test",
+#     allow_empty_target = False,
+#     targets = [
+#         "//:TritonAnalysis",
+#         "//:TritonDialects",
+#         "//:TritonGPUToLLVM",
+#         "//:TritonGPUTransforms",
+#         "//:TritonLLVMIR",
+#         "//:TritonPTX",
+#         "//:TritonToTritonGPU",
+#         "//:TritonTools",
+#         "//:TritonTransforms",
+#         "//:triton-opt",
+#     ],
+# )
+# copybara:uncomment_end
+
+cc_library(
+    name = "TritonTestAnalysis",
+    srcs = glob(["lib/Analysis/*.cpp"]),
+    deps = [
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFToControlFlow",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+    ],
+)
diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir
index 7ecee2eba11b..3c16ea0260ef 100644
--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir
+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -129,24 +129,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: test_fp8_to_f16_conversion
   tt.func @test_fp8_to_f16_conversion(
-    %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
+    %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>,
     %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
     // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
     %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
     // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
-    %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
+    %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked>
     // CHECK-COUNT-2: mul.rn.bf16x2
     %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>
 
     // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
     %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
     // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
-    %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+    %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked>
 
     // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
     %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
     // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
-    %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+    %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked>
     tt.return
   }
 }
diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir
index 923324616940..50f39c1bf970 100644
--- a/test/TritonGPU/accelerate-matmul.mlir
+++ b/test/TritonGPU/accelerate-matmul.mlir
@@ -142,3 +142,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return
   }
 }
+
+// -----
+
+// CHECK-DAG: #[[$BLOCKED:.*]] = #triton_gpu.blocked
+// CHECK-DAG: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @local_alloc_dot_operand(%in0:  tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) {
+    // CHECK-LABEL: local_alloc_dot_operand
+    %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc
+    // CHECK: %[[RHS_CVT:.*]] = triton_gpu.convert_layout {{.*}} #triton_gpu.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]]
+    // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc %[[RHS_CVT]]
+    // CHECK: triton_nvidia_gpu.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]]
+    %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked>
+    tt.return %res :  tensor<64x32xf32, #blocked>
+  }
+}
diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir
index ecee359cb19a..f015f9651065 100644
--- a/test/TritonGPU/canonicalize.mlir
+++ b/test/TritonGPU/canonicalize.mlir
@@ -133,3 +133,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
   }
 }  // end module
+
+// -----
+
+// CHECK: #[[$BLOCKED:.*]] = #triton_gpu.blocked
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<256x32xf32, #shared1> {
+    // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized
+    %cvt_in = triton_gpu.convert_layout %in : tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked>
+    %alloc = triton_gpu.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !tt.memdesc<256x32xf32, #shared1>
+    // CHECK: %[[ALLOC:.*]] = triton_gpu.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) ->
+    tt.return %alloc : !tt.memdesc<256x32xf32, #shared1>
+  }
+} // end module
+
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index cf93c37b840d..16550b98259c 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -131,3 +131,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+
+module {
+  tt.func @int_min_does_not_underflow_in_analysis() -> i64 {
+    %int_min = arith.constant -9223372036854775808 : i64
+    tt.return %int_min : i64
+  }
+}
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 1e47e1449f0f..04281fe64d14 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -47,11 +47,14 @@ tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
   %5 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1>
   %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1>
   tt.return %6: tensor<1024xi32, #layout1>
-  // CHECK: %[[A:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
-  // CHECK: %[[B:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
-  // CHECK: %[[C:.+]] = arith.muli %[[A]], %[[B]] : tensor<1024xi32, [[$target_layout]]>
-  // CHECK: %[[D:.+]] = arith.addi %[[C]], %[[C]] : tensor<1024xi32, [[$target_layout]]>
-  // CHECK: tt.return %[[D]] : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]>
+  // CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]>
 }
 
 // Always rematerialize single value loads
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
index 3a9e80c04b0a..848e2b88f342 100644
--- a/test/TritonGPU/prefetch.mlir
+++ b/test/TritonGPU/prefetch.mlir
@@ -173,3 +173,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+// CHECK: tt.func @matmul_loop_on_blocked_layout
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !tt.memdesc<16x512xf32, #shared, mutable>, %arg_rhs: !tt.memdesc<512x32xf32, #shared, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) {
+    %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>) : i32 {
+      %lhs_ll = triton_gpu.local_load %lhs : !tt.memdesc<16x512xf32, #shared, mutable> -> tensor<16x512xf32, #blocked>
+      %lhs_ll_cvt = triton_gpu.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+      %rhs_ll = triton_gpu.local_load %rhs : !tt.memdesc<512x32xf32, #shared, mutable> -> tensor<512x32xf32, #blocked>
+      %rhs_ll_cvt = triton_gpu.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+      %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked>
+      scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>
+    }
+    tt.return %loop#0 : tensor<16x32xf32, #blocked>
+  }
+} // end module
diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD
new file mode 100644
index 000000000000..d32c944f37ed
--- /dev/null
+++ b/third_party/amd/BUILD
@@ -0,0 +1,142 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+
+package(
+    # copybara:uncomment_begin
+    # default_applicable_licenses = ["//:license"],
+    # default_compatible_with = ["//buildenv/target:gce"],
+    # default_visibility = [
+        # "//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__",
+        # "//:__subpackages__",
+    # ],
+    # copybara:uncomment_end_and_comment_begin
+    default_visibility = ["//visibility:public"],
+    # copybara:comment_end
+)
+
+# TODO(csigg): fix, enable error upstream, remove.
+_no_unused_variable = select({
+    "//:compiler_is_msvc": [],
+    "//conditions:default": ["-Wno-unused-variable"],
+})
+
+cc_library(
+    name = "TritonAMDGPUTransforms",
+    srcs = glob([
+        "lib/TritonAMDGPUTransforms/**/*.h",
+        "lib/TritonAMDGPUTransforms/**/*.cpp",
+    ]) + ["include/TritonAMDGPUToLLVM/TargetUtils.h"],
+    hdrs = glob([
+        "include/TritonAMDGPUTransforms/**/*.h",
+    ]),
+    copts = _no_unused_variable,
+    includes = [
+        "include",
+        "lib/TritonAMDGPUTransforms",
+    ],
+    deps = [
+        ":triton_conversion_amdgpu_transforms_passes_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TargetParser",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ConvertToLLVM",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ROCDLDialect",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SideEffectInterfaces",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+        "//:TritonGPUTransforms",
+    ],
+)
+
+cc_library(
+    name = "TritonAMDGPUToLLVM",
+    srcs = glob([
+        "lib/TritonAMDGPUToLLVM/**/*.h",
+        "lib/TritonAMDGPUToLLVM/**/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/TritonAMDGPUToLLVM/**/*.h",
+    ]),
+    copts = _no_unused_variable,
+    includes = [
+        "include",
+        "lib/TritonAMDGPUToLLVM",
+    ],
+    deps = [
+        ":TritonAMDGPUTransforms",
+        ":triton_conversion_amdgpu_to_llvm_passes_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TargetParser",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ArithToLLVM",
+        "@llvm-project//mlir:ControlFlowToLLVM",
+        "@llvm-project//mlir:ConvertToLLVM",
+        "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:GPUToROCDLTransforms",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathToLLVM",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ROCDLDialect",
+        "@llvm-project//mlir:SCFToControlFlow",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+    ],
+)
+
+td_library(
+    name = "td_files",
+    srcs = glob(["include/**/*.td"]),
+    includes = ["include"],
+    deps = ["//:td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonAMDGPUToLLVM",
+            ],
+            "include/TritonAMDGPUToLLVM/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/TritonAMDGPUToLLVM/Passes.td",
+    deps = [":td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_amdgpu_transforms_passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonAMDGPU",
+            ],
+            "include/TritonAMDGPUTransforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/TritonAMDGPUTransforms/Passes.td",
+    deps = [":td_files"],
+)
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
index 99dad006dba3..3c429a3d724a 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
@@ -128,7 +128,7 @@ class CallOpConversion : public mlir::RewritePattern {
     auto operands = callOp.getOperands();
     auto result = callOp.getResult();
 
-    LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value();
+    LLVM::LLVMFunctionType calleeType = callOp.getCalleeFunctionType();
     Type returnType = calleeType.getReturnType();
 
     auto loc = callOp.getLoc();
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index a6ef2fec7c67..598df0fba7db 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -195,9 +195,7 @@ void init_triton_amd(py::module &&m) {
             target->createMCAsmBackend(*sti, *mri, mcOptions));
         mcStreamer.reset(target->createMCObjectStreamer(
             triple, ctx, std::move(mab), mab->createObjectWriter(svos),
-            std::move(ce), *sti, mcOptions.MCRelaxAll,
-            mcOptions.MCIncrementalLinkerCompatible,
-            /*DWARFMustBeAtTheEnd=*/false));
+            std::move(ce), *sti));
 
         std::unique_ptr<llvm::MCAsmParser> parser(
             createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD
new file mode 100644
index 000000000000..a1a4f3a7ca02
--- /dev/null
+++ b/third_party/f2reduce/BUILD
@@ -0,0 +1,31 @@
+# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license")
+
+package(
+    # copybara:uncomment_begin
+    # default_applicable_licenses = ["//:license"],
+    # default_compatible_with = ["//buildenv/target:gce"],
+    # default_visibility = [
+        # "//:__subpackages__",
+    # ],
+    # copybara:uncomment_end_and_comment_begin
+    default_visibility = ["//visibility:public"],
+    # copybara:comment_end
+)
+
+# copybara:uncomment_begin
+# license(
+#     name = "license",
+#     license_text = "LICENCE.txt",
+# )
+# 
+# licenses(["notice"])
+# 
+# exports_files(["LICENCE.txt"])
+# copybara:uncomment_end
+
+cc_library(
+    name = "f2reduce",
+    srcs = ["f2reduce.cpp"],
+    hdrs = ["f2reduce.h"],
+    # copybara:uncomment strip_include_prefix = "/third_party/triton",
+)
diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD
new file mode 100644
index 000000000000..6af127c11ec6
--- /dev/null
+++ b/third_party/nvidia/BUILD
@@ -0,0 +1,306 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
+
+package(
+    # copybara:uncomment_begin
+    # default_applicable_licenses = ["//:license"],
+    # default_compatible_with = ["//buildenv/target:gce"],
+    # default_visibility = [
+        # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__",
+        # "//:__subpackages__",
+    # ],
+    # copybara:uncomment_end_and_comment_begin
+    default_visibility = ["//visibility:public"],
+    # copybara:comment_end
+)
+
+pybind_library(
+    name = "cublas_headers",
+    hdrs = glob([
+        "include/*.h",
+    ]),
+    deps = ["@local_config_cuda//cuda:cuda_headers"],
+)
+
+pybind_library(
+    name = "triton_nvidia",
+    srcs = [
+        "triton_nvidia.cc",
+    ],
+    compatible_with = [],
+    # copybara:uncomment_begin
+    # visibility = [
+        # "@triton//python:__subpackages__",
+    # ],
+    # copybara:uncomment_end
+    deps = [
+        ":NVGPUDialect",
+        ":NVGPUToLLVM",
+        ":TritonNVIDIAGPUToLLVM",
+        ":cublas_headers",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+        "//:TritonNvidiaGPUTransforms",
+        "@triton//python:passes",
+    ],
+)
+
+cc_library(
+    name = "NVGPUToLLVM",
+    srcs = glob([
+        "lib/NVGPUToLLVM/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/NVGPUToLLVM/*.h",
+    ]),
+    # copybara:uncomment_begin
+    # compatible_with = ["//buildenv/target:gce"],
+    # copybara:uncomment_end
+    copts = select({
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    includes = [
+        "..",
+        "include",
+    ],
+    deps = [
+        ":NVGPUDialect",
+        ":TritonNVIDIAGPUToLLVM",
+        ":triton_conversion_nvgpu_to_llvm_passes_inc_gen",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonDialects",
+    ],
+)
+
+cc_library(
+    name = "TritonNVIDIAGPUToLLVM",
+    srcs = glob([
+        "lib/TritonNVIDIAGPUToLLVM/*.h",
+        "lib/TritonNVIDIAGPUToLLVM/**/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/TritonNVIDIAGPUToLLVM/*.h",
+    ]) + [
+        "lib/TritonNVIDIAGPUToLLVM/Utility.h",
+    ],
+    # copybara:uncomment_begin
+    # compatible_with = ["//buildenv/target:gce"],
+    # copybara:uncomment_end
+    copts = select({
+        "//conditions:default": [
+            "-Wno-reorder-ctor",
+            "-Wno-unused-variable",
+        ],
+    }),
+    includes = [
+        "..",
+        "include",
+        "lib/TritonNVIDIAGPUToLLVM",
+    ],
+    deps = [
+        ":NVGPUDialect",
+        ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithToLLVM",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ControlFlowToLLVM",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathToLLVM",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFToControlFlow",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+        "//:triton_gpu_attr_inc_gen",
+    ],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen",
+    # copybara:uncomment_begin
+    # compatible_with = ["//buildenv/target:gce"],
+    # copybara:uncomment_end
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=NVGPUToLLVM",
+            ],
+            "include/NVGPUToLLVM/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/NVGPUToLLVM/Passes.td",
+    deps = ["//:td_files"],
+)
+
+gentbl_cc_library(
+    name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen",
+    # copybara:uncomment_begin
+    # compatible_with = ["//buildenv/target:gce"],
+    # copybara:uncomment_end
+    tbl_outs = [
+        (
+            [
+                "--gen-pass-decls",
+                "--name=TritonNVIDIAGPUToLLVM",
+            ],
+            "include/TritonNVIDIAGPUToLLVM/Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td",
+    deps = ["//:td_files"],
+)
+
+td_library(
+    name = "td_files",
+    srcs = glob(["include/Dialect/NVGPU/IR/*.td"]),
+    includes = ["include"],
+    deps = [
+        "@llvm-project//mlir:ArithOpsTdFiles",
+        "@llvm-project//mlir:CastInterfacesTdFiles",
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+        "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
+        "@llvm-project//mlir:FunctionInterfacesTdFiles",
+        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+        "@llvm-project//mlir:LLVMOpsTdFiles",
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:PassBaseTdFiles",
+        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
+        "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "nvgpu_ops_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-llvmir-conversions"],
+            "include/Dialect/NVGPU/IR/OpsConversions.inc",
+        ),
+        (
+            ["--gen-op-decls"],
+            "include/Dialect/NVGPU/IR/Ops.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "include/Dialect/NVGPU/IR/Ops.cpp.inc",
+        ),
+        (
+            ["--gen-enum-decls"],
+            "include/Dialect/NVGPU/IR/OpsEnums.h.inc",
+        ),
+        (
+            ["--gen-enum-defs"],
+            "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "nvgpu_attr_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-attrdef-decls"],
+            "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc",
+        ),
+        (
+            ["--gen-attrdef-defs"],
+            "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td",
+    deps = ["td_files"],
+)
+
+gentbl_cc_library(
+    name = "nvgpu_dialect_inc_gen",
+    tbl_outs = [
+        (
+            ["--gen-dialect-decls"],
+            "include/Dialect/NVGPU/IR/Dialect.h.inc",
+        ),
+        (
+            ["--gen-dialect-defs"],
+            "include/Dialect/NVGPU/IR/Dialect.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td",
+    deps = ["td_files"],
+)
+
+cc_library(
+    name = "NVGPUDialect",
+    srcs = glob([
+        "lib/Dialect/NVGPU/IR/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/Dialect/NVGPU/IR/*.h",
+    ]),
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+            "-Wno-logical-op-parentheses",
+        ],
+    }),
+    includes = [
+        "..",  # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc
+        "../..",  # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h
+        "include",
+    ],
+    deps = [
+        ":nvgpu_attr_inc_gen",
+        ":nvgpu_dialect_inc_gen",
+        ":nvgpu_ops_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FunctionInterfaces",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InliningUtils",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        # The following is added to make Utility compile
+        "//:TritonTools",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD
new file mode 100644
index 000000000000..a5b34aa5c29b
--- /dev/null
+++ b/third_party/nvidia/backend/BUILD
@@ -0,0 +1,30 @@
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+package(
+    default_applicable_licenses = ["//:license"],
+    default_visibility = [
+        "//third_party/py/triton:__subpackages__",
+    ],
+)
+
+pybind_extension(
+    name = "cuda_utils",
+    srcs = ["cuda_utils.cc"],
+    visibility = [
+        "//learning/deepmind/jax/triton/ops:__subpackages__",
+        "//third_party/py/triton:__subpackages__",
+    ],
+    deps = [
+        "//platforms/gpus/cuda/dynamic_libcuda",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cuda_runtime",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+filegroup(
+    name = "files",
+    srcs = glob(
+        include = ["**/*.py"],
+    ),
+)
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index f476f5bffd73..ee36757bef2e 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -154,6 +154,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
 typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
     int *numClusters, CUfunction func, const CUlaunchConfig *config);
 
+#if CUDA_VERSION >= 12000
 typedef CUresult (*cuTensorMapEncodeTiled_t)(
     CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
     cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
@@ -161,6 +162,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
     const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
     CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
     CUtensorMapFloatOOBfill oobFill);
+#endif
 
 #define defineGetFunctionHandle(name, symbolName)                              \
   static symbolName##_t name() {                                               \
@@ -187,8 +189,10 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
 defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
                         cuOccupancyMaxActiveClusters);
 
+#if CUDA_VERSION >= 12000
 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                         cuTensorMapEncodeTiled);
+#endif
 
 static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
   int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
@@ -280,6 +284,9 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
 // Simple helper to experiment creating TMA descriptors on the host.
 // This is a useful to test TMA operations independently.
 static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
+#if CUDA_VERSION < 12000
+  return NULL;
+#else
   unsigned long long global_address;
   uint64_t dim;
   uint32_t tensorDim;
@@ -318,11 +325,15 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
       CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
       CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
   return Py_None;
+#endif
 }
 
 // Simple helper to experiment creating TMA descriptors on the host.
 // This is a useful to test TMA operations independently.
 static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
+#if CUDA_VERSION < 12000
+  return NULL;
+#else
   unsigned long long global_address;
   uint64_t dims[2];
   uint32_t tensorDims[2];
@@ -380,6 +391,7 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
       swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
       CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
   return Py_None;
+#endif
 }
 
 static PyMethodDef ModuleMethods[] = {
diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
index b075ca31a407..b474996f11c2 100644
--- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
+++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
@@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
 
   Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
     auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
-    uint32_t numOutputRegs = outputStructType.getBody().size();
-    std::string output =
-        outputStructType.getBody().front().isF32() ? "=f" : "=r";
-    return Constraints(numOutputRegs, output);
+    std::vector<std::string> outputConstraints;
+    outputConstraints.reserve(outputStructType.getBody().size());
+    for (mlir::Type type : outputStructType.getBody()) {
+      if (type.isF32()) {
+        outputConstraints.push_back("=f");
+        continue;
+      } else if (type.isF64()) {
+        outputConstraints.push_back("=d");
+        continue;
+      }
+      unsigned bitwidth = isa<LLVM::LLVMPointerType>(type) ?
+          64 : type.getIntOrFloatBitWidth();
+      switch (bitwidth) {
+        case 1:
+          outputConstraints.push_back("=b");
+          break;
+        case 16:
+          outputConstraints.push_back("=h");
+          break;
+        case 32:
+          outputConstraints.push_back("=r");
+          break;
+        case 64:
+          outputConstraints.push_back("=l");
+          break;
+        default:
+          assert(false && "unsupported bitwidth");
+      }
+    }
+    return outputConstraints;
   }
 
   OperandsAndConstraints
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index d1086c189d33..86aa5c3c5ee6 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -496,6 +496,8 @@ Type getSharedMemTy(Type argType) {
     return type::f32Ty(ctx);
   else if (argType.getIntOrFloatBitWidth() == 8)
     return type::i8Ty(ctx);
+  else if (argType.getIntOrFloatBitWidth() == 16)
+    return type::i16Ty(ctx);
   else
     llvm::report_fatal_error("mma16816 data type not supported");
 }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 2d16dc19b3b3..af897ef546dd 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -81,9 +81,9 @@ enum class TensorCoreType : uint8_t {
   FP32_TF32_TF32_FP32,
   FP16_FP16_FP16_FP16,
   FP32_FP8E5M2_FP8E5M2_FP32,
-  FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
-  FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
-  FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+  FP32_FP8E5M2_FP8E4M3FN_FP32,
+  FP32_FP8E4M3FN_FP8E5M2_FP32,
+  FP32_FP8E4M3FN_FP8E4M3FN_FP32,
   // integer tensor core instr
   INT32_INT1_INT1_INT32, // Not implemented
   INT32_INT4_INT4_INT32, // Not implemented
@@ -112,9 +112,9 @@ Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
   case TensorCoreType::FP16_FP16_FP16_FP16:
     return fp16x2Pack2Ty;
   case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32:
-  case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32:
-  case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32:
-  case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32:
+  case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32:
+  case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
+  case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
     return fp32x4Ty;
   case TensorCoreType::INT32_INT8_INT8_INT32:
     return i32x4Ty;
@@ -140,14 +140,14 @@ TensorCoreType getMmaType(triton::DotOp op) {
         bTy.getElementType().isFloat8E5M2())
       return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
     if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E4M3FNUZ())
-      return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32;
-    if (aTy.getElementType().isFloat8E4M3FNUZ() &&
+        bTy.getElementType().isFloat8E4M3FN())
+      return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
+    if (aTy.getElementType().isFloat8E4M3FN() &&
         bTy.getElementType().isFloat8E5M2())
-      return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E4M3FNUZ() &&
-        bTy.getElementType().isFloat8E4M3FNUZ())
-      return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32;
+      return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
+    if (aTy.getElementType().isFloat8E4M3FN() &&
+        bTy.getElementType().isFloat8E4M3FN())
+      return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
     if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
         op.getInputPrecision() == InputPrecision::TF32)
       return TensorCoreType::FP32_TF32_TF32_FP32;
@@ -193,11 +193,11 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
 
     {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32,
      "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"},
-    {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
+    {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32,
      "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"},
-    {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
+    {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32,
      "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"},
-    {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+    {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
      "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
 };
 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index baed96a29704..41e36503f593 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -58,7 +58,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::s8;
   } else if (aTy.isFloat8E5M2()) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FNUZ()) {
+  } else if (aTy.isFloat8E4M3FN()) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index 2fabb598e99d..80072edf3bdf 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -386,7 +386,7 @@ struct FpToFpOpConversion
   std::pair<ConverterT, size_t>
   getConversionFunc(Type srcTy, Type dstTy,
                     std::optional<RoundingMode> roundingMode) const {
-    auto F8E4M3TyID = TypeID::get<Float8E4M3FNUZType>();
+    auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
     auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
     auto F16TyID = TypeID::get<Float16Type>();
     auto BF16TyID = TypeID::get<BFloat16Type>();
@@ -430,7 +430,7 @@ struct FpToFpOpConversion
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
     if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
+        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -452,7 +452,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) {
+    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -489,7 +489,7 @@ struct FpToFpOpConversion
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() ||
+        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
                                        dstElementType.isFloat8E5M2())) ||
          roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index 9ee532992d01..cc700e7186f4 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -272,6 +272,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
         ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
 
       if (other) {
+        if (otherIsSplatConstInt) {
+          for (size_t s = valueElemNBits; s < movWidth; s += valueElemNBits) {
+            splatVal |= splatVal << valueElemNBits;
+          }
+        }
+
         for (size_t ii = 0; ii < nWords; ++ii) {
           // PTX doesn't support mov.u8, so we need to use mov.u16
           PTXInstr &mov =
@@ -292,8 +298,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
           PTXInstr::Operand *opr{};
 
           if (otherIsSplatConstInt) {
-            for (size_t s = 0; s < 32; s += valueElemNBits)
-              splatVal |= splatVal << valueElemNBits;
             opr = ptxBuilder.newConstantOperand(splatVal);
           } else
             opr = ptxBuilder.newOperand(v, readConstraint);
diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc
index 1269dcda00aa..3cccc5fb6a1c 100644
--- a/third_party/nvidia/triton_nvidia.cc
+++ b/third_party/nvidia/triton_nvidia.cc
@@ -1,4 +1,4 @@
-#include "Dialect/NVGPU/IR/Dialect.h"
+#include "Dialect/NVGPU/IR/Dialect.h"
 #include "NVGPUToLLVM/NVGPUToLLVMPass.h"
 #include "TritonNVIDIAGPUToLLVM/Passes.h"
 #include "cublas_instance.h"
diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include
index fe4f4a1aa9bd..4400934bdf78 120000
--- a/third_party/proton/proton/_C/include
+++ b/third_party/proton/proton/_C/include
@@ -1 +1 @@
-../../csrc/include/
\ No newline at end of file
+../../csrc/include
\ No newline at end of file
diff --git a/unittest/BUILD b/unittest/BUILD
new file mode 100644
index 000000000000..cb885459e19c
--- /dev/null
+++ b/unittest/BUILD
@@ -0,0 +1,144 @@
+load("//tools/build_defs/build_test:build_test.bzl", "build_test")
+
+package(
+    default_applicable_licenses = ["//:license"],
+    default_compatible_with = ["//buildenv/target:gce"],
+    default_visibility = ["//:__subpackages__"],
+)
+
+cc_test(
+    name = "AnalysisTest",
+    srcs = glob(["Analysis/*.cpp"]),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "DialectTestCatchAll",
+    srcs = glob(
+        [
+            "Dialect/**/*.cpp",
+        ],
+        exclude = [
+            "Dialect/TritonGPU/DialectTest.cpp",
+            "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp",
+            "Dialect/TritonGPU/SwizzleTest.cpp",
+        ],
+    ),
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "DialectTest",
+    srcs = [
+        "Dialect/TritonGPU/DialectTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "LinearLayoutConversionsTest",
+    srcs = [
+        "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "SwizzleTest",
+    srcs = [
+        "Dialect/TritonGPU/SwizzleTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "ConversionTest",
+    srcs = glob(
+        [
+            "Conversion/**/*.cpp",
+            "Conversion/**/*.h",
+        ],
+        exclude = [
+            "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp",
+            "Conversion/TritonGPUToLLVM/DumpLayout.cpp",
+            "Conversion/TritonGPUToLLVM/DumpLayout.h",
+        ],
+    ),
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+        "//:TritonNvidiaGPUTransforms",
+        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
+    ],
+)
+
+build_test(
+    name = "build_test",
+    allow_empty_target = False,
+    targets = [
+        ":ConversionTest",
+        ":AnalysisTest",
+        ":DialectTest",
+    ],
+)