From 8243b50282b890ee40c00cd478819143d7749d76 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Mon, 11 Dec 2023 17:32:31 -0800
Subject: [PATCH] Add Pattern Boundary Marking API (#5930)

Co-authored-by: Siyuan Liu
Co-authored-by: Chunnien Chan <121328115+chunnienc@users.noreply.github.com>
---
 WORKSPACE                                     |   8 +
 bazel/nlohmann_json.BUILD                     |   9 +
 test/stablehlo/test_mark_pattern.py           | 183 +++++++
 torch_xla/csrc/init_python_bindings.cpp       |  17 +-
 torch_xla/csrc/ops/mark_tensor.cpp            |  35 ++
 torch_xla/csrc/ops/mark_tensor.h              |  24 +
 torch_xla/csrc/ops/xla_ops.cpp                |   1 +
 torch_xla/csrc/ops/xla_ops.h                  |   1 +
 torch_xla/csrc/runtime/BUILD                  |  13 +
 .../runtime/stablehlo_composite_helper.cc     | 465 ++++++++++++++++++
 .../csrc/runtime/stablehlo_composite_helper.h |  20 +
 torch_xla/csrc/runtime/stablehlo_helper.cc    |   7 +
 torch_xla/csrc/tensor_methods.cpp             |   7 +
 torch_xla/csrc/tensor_methods.h               |   2 +
 torch_xla/experimental/mark_pattern_utils.py  |  81 +++
 torch_xla/experimental/xla_marker.py          |  90 ++++
 16 files changed, 962 insertions(+), 1 deletion(-)
 create mode 100644 bazel/nlohmann_json.BUILD
 create mode 100644 test/stablehlo/test_mark_pattern.py
 create mode 100644 torch_xla/csrc/ops/mark_tensor.cpp
 create mode 100644 torch_xla/csrc/ops/mark_tensor.h
 create mode 100644 torch_xla/csrc/runtime/stablehlo_composite_helper.cc
 create mode 100644 torch_xla/csrc/runtime/stablehlo_composite_helper.h
 create mode 100644 torch_xla/experimental/mark_pattern_utils.py
 create mode 100644 torch_xla/experimental/xla_marker.py

diff --git a/WORKSPACE b/WORKSPACE
index 76a20826678f..6d9f43678da8 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -16,6 +16,14 @@ http_archive(
     urls = ["https://github.com/pybind/pybind11/archive/442261da585536521ff459b1457b2904895f23b4.tar.gz"],
 )
 
+http_archive(
+    name = "com_nlohmann_json",
+    build_file = "//bazel:nlohmann_json.BUILD",
+    sha256 = "d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273",
+    strip_prefix = "json-3.11.2",
+    url = "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz",
+)
+
 load("@pybind11_bazel//:python_configure.bzl", "python_configure")
 
 # This is required for setting up the linkopts for -lpython.
diff --git a/bazel/nlohmann_json.BUILD b/bazel/nlohmann_json.BUILD
new file mode 100644
index 000000000000..eb868ef5c6aa
--- /dev/null
+++ b/bazel/nlohmann_json.BUILD
@@ -0,0 +1,9 @@
+cc_library(
+    name = "json",
+    hdrs = [
+        "single_include/nlohmann/json.hpp",
+        "single_include/nlohmann/json_fwd.hpp",
+    ],
+    includes = ["single_include"],
+    visibility = ["//visibility:public"],
+)
diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py
new file mode 100644
index 000000000000..1980e236f75a
--- /dev/null
+++ b/test/stablehlo/test_mark_pattern.py
@@ -0,0 +1,183 @@
+import sys
+import unittest
+
+import torch
+import torch.nn.functional as F
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.xla_marker
+from torch.utils import _pytree as pytree
+from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
+
+
+class XlaMarkPatternTest(unittest.TestCase):
+
+  def run_func_get_stablehlo(self, f, input_args):
+    device = xm.xla_device()
+    input_args = pytree.tree_map_only(torch.Tensor,
+                                      lambda x: x.to(device=device),
+                                      input_args)
+    out = f(*input_args)
+    if isinstance(out, tuple):
+      out = list(out)
+    else:
+      out = [out]
+    stablehlo = xm.get_stablehlo(out)
+    return stablehlo
+
+  def test_basic(self):
+
+    def f(x):
+      x = x + 1
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      x = x + 2
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
+      return x
+
+    input_args = (torch.randn(5),)
+    stablehlo = self.run_func_get_stablehlo(f, input_args)
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
+    self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)
+
+  def test_sdpa_pattern(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+
+      def forward(self, x, y):
+        q, k, v = x.split(128, dim=-2)
+        q = torch.ops.xla_pattern_marking.mark_tensor(
+            q, "sdpa", pos=0, id="0", is_input=True)
+        k = torch.ops.xla_pattern_marking.mark_tensor(
+            k, "sdpa", pos=1, id="0", is_input=True)
+        v = torch.ops.xla_pattern_marking.mark_tensor(
+            v, "sdpa", pos=2, id="0", is_input=True)
+        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
+        attn_out = torch.ops.xla_pattern_marking.mark_tensor(
+            attn_out,
+            "sdpa",
+            pos=0,
+            id="0",
+            is_input=False,
+            attr={"scale": 0.25})
+        q, k, v = y.split(128, dim=-2)
+        q = torch.ops.xla_pattern_marking.mark_tensor(
+            q, "sdpa", pos=0, id="1", is_input=True)
+        k = torch.ops.xla_pattern_marking.mark_tensor(
+            k, "sdpa", pos=1, id="1", is_input=True)
+        v = torch.ops.xla_pattern_marking.mark_tensor(
+            v, "sdpa", pos=2, id="1", is_input=True)
+        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
+        attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
+            attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
+        return attn_out, attn_out2
+
+    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
+    stablehlo = self.run_func_get_stablehlo(M(), input_args)
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
+    self.assertTrue(
+        '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
+        stablehlo)
+    self.assertTrue(
+        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
+
+  def test_composite_builder_sdpa_pattern(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+
+      def forward(self, x, y):
+        b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
+        q, k, v = x.split(128, dim=-2)
+        q, k, v = b.mark_inputs(q, k, v)
+        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
+        attn_out = b.mark_outputs(attn_out)
+
+        b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
+        q, k, v = y.split(128, dim=-2)
+        q, k, v = b2.mark_inputs(q, k, v)
+        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
+        attn_out2 = b2.mark_outputs(attn_out2)
+        return attn_out, attn_out2
+
+    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
+    stablehlo = self.run_func_get_stablehlo(M(), input_args)
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
+    self.assertTrue(
+        '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
+        stablehlo)
+    self.assertTrue(
+        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
+
+  def test_multiple_input(self):
+
+    def f(x, y):
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      out = x + y
+      out = out * x * y
+      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
+      return out
+
+    input_args = (torch.ones(5), torch.ones(5))
+    stablehlo = self.run_func_get_stablehlo(f, input_args)
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
+    self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)
+
+  @unittest.skip("Multiple-output patterns are not supported yet.")
+  def test_multiple_output(self):
+
+    def f(x, y):
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      out1 = x + y
+      out2 = x * y
+      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
+      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
+      return out1, out2
+
+    input_args = (torch.ones(5), torch.ones(5))
+    stablehlo = self.run_func_get_stablehlo(f, input_args)
+
+  @unittest.skip("Nested patterns are not supported yet.")
+  def test_nested_pattern(self):
+
+    def f(x):
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outer", 0, "0", True)
+      x = x + 1
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = x + 1
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      x = x * 2
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outer", 0, "0", False)
+      return x
+
+    input_args = (torch.ones(5),)
+    stablehlo = self.run_func_get_stablehlo(f, input_args)
+
+  @unittest.skip("Nested patterns are not supported yet.")
+  def test_tangent_output(self):
+    # Special case of a nested pattern: the two outputs have no dependency on
+    # each other.
+    def f(x):
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outer", 0, "0", True)
+      x = x + 1
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = x + 1
+      y = x - 1
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outer", 0, "0", False)
+      return x, y
+
+    input_args = (torch.ones(5),)
+    stablehlo = self.run_func_get_stablehlo(f, input_args)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index e36dca23f199..03c5fa95ce1b 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -740,6 +740,12 @@ void MapXlaEnvVarsToLazy() {
       runtime::sys_util::GetEnvInt("XLA_TRIM_GRAPH_SIZE", 100000);
 }
 
+at::Tensor MarkTensor(const at::Tensor& input, const std::string& info) {
+  XLATensorPtr result =
+      tensor_methods::mark_tensor(bridge::GetXlaTensor(input), info);
+  return bridge::AtenFromXlaTensor(std::move(result));
+}
+
 std::string GetPyTypeString(py::handle obj) {
   std::string type = obj.attr("__class__").attr("__name__").cast<std::string>();
   return type;
 }
@@ -2172,7 +2178,16 @@
           }
           return handles;
         });
-
+  m.def("_xla_mark_tensor",
+        [](const at::Tensor& input, const std::string& info) {
+          TORCH_LAZY_COUNTER("XlaMarkTensor", 1);
+          at::Tensor result;
+          {
+            NoGilSection nogil;
+            result = MarkTensor(input, info);
+          }
+          return result;
+        });
   m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
     TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
diff --git a/torch_xla/csrc/ops/mark_tensor.cpp b/torch_xla/csrc/ops/mark_tensor.cpp
new file mode 100644
index 000000000000..6db158440813
--- /dev/null
+++ b/torch_xla/csrc/ops/mark_tensor.cpp
@@ -0,0 +1,35 @@
+#include "torch_xla/csrc/ops/mark_tensor.h"
+
+#include <sstream>
+
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "torch_xla/csrc/shape_helper.h"
+
+namespace torch_xla {
+
+MarkTensor::MarkTensor(const torch::lazy::Value& input,
+                       const std::string& info)
+    : XlaNode(xla_mark_tensor, {input}, GetXlaShape(input),
+              /*num_outputs=*/1,
+              torch::lazy::MHash(info)),
+      info_(info) {}
+
+torch::lazy::NodePtr MarkTensor::Clone(torch::lazy::OpList operands) const {
+  return torch::lazy::MakeNode<MarkTensor>(operands.at(0), info_);
+}
+
+XlaOpVector MarkTensor::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  static const std::string opname = "xla_mark_tensor";
+  xla::XlaOp output =
+      xla::CustomCall(input.builder(), opname, {input}, input_shape, info_);
+  return ReturnOp(output, loctx);
+}
+
+std::string MarkTensor::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", info=" << info_;
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mark_tensor.h b/torch_xla/csrc/ops/mark_tensor.h
new file mode 100644
index 000000000000..ae177e8b0821
--- /dev/null
+++ b/torch_xla/csrc/ops/mark_tensor.h
@@ -0,0 +1,24 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_
+#define XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class MarkTensor : public XlaNode {
+ public:
+  MarkTensor(const torch::lazy::Value& input, const std::string& info);
+
+  std::string ToString() const override;
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  std::string info_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_
diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp
index a9a7f9d62fcc..e515b64350d7 100644
--- a/torch_xla/csrc/ops/xla_ops.cpp
+++ b/torch_xla/csrc/ops/xla_ops.cpp
@@ -15,6 +15,7 @@ const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update");
 const OpKindWrapper xla_einsum_backward("xla::einsum_backward");
 const OpKindWrapper xla_generic_slice("xla::generic_slice");
 const OpKindWrapper xla_get_dimensions_size("xla::xla_get_dimensions_size");
+const OpKindWrapper xla_mark_tensor("xla::mark_tensor");
 const OpKindWrapper xla_moving_average("xla::moving_average");
 const OpKindWrapper xla_nms("xla::nms");
 const OpKindWrapper xla_not_supported("xla::not_supported");
diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h
index f39227dd6dd0..fa6c0525da48 100644
--- a/torch_xla/csrc/ops/xla_ops.h
+++ b/torch_xla/csrc/ops/xla_ops.h
@@ -41,6 +41,7 @@ extern const OpKindWrapper xla_diagonal_view_update;
 extern const OpKindWrapper xla_einsum_backward;
 extern const OpKindWrapper xla_generic_slice;
 extern const OpKindWrapper xla_get_dimensions_size;
+extern const OpKindWrapper xla_mark_tensor;
 extern const OpKindWrapper xla_moving_average;
 extern const OpKindWrapper xla_nms;
 extern const OpKindWrapper xla_not_supported;
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 9d58adaa9442..0df2c215219e 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -266,6 +266,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "stablehlo_composite_helper",
+    srcs = ["stablehlo_composite_helper.cc"],
+    hdrs = ["stablehlo_composite_helper.h"],
+    deps = [
+        ":types",
+        ":xla_util",
+        "@com_nlohmann_json//:json",
+        "@xla//xla/mlir_hlo:all_passes",
+    ],
+)
+
 cc_library(
     name = "stablehlo_helper",
     srcs = ["stablehlo_helper.cc"],
@@ -273,6 +285,7 @@ cc_library(
     deps = [
         ":types",
         ":xla_util",
+        ":stablehlo_composite_helper",
         "@stablehlo//:stablehlo_portable_api",
         "@stablehlo//:stablehlo_serialization",
         "@xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
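
Note on the contract between the op lowering above and the passes in the next file: MarkTensor::Lower emits a stablehlo.custom_call to xla_mark_tensor whose backend_config carries the marker metadata as a JSON string, and BoundaryMetadata::Parse below expects exactly the keys "name", "id", "pos", "is_input", and an optional "attr" object. A minimal sketch of that payload, driving the raw _xla_mark_tensor binding directly (in normal use, torch.ops.xla_pattern_marking.mark_tensor from xla_marker.py below builds this JSON; the tensor and values here are illustrative only):

    # Sketch: hand-built JSON matching what BoundaryMetadata::Parse consumes.
    import json

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    x = torch.randn(5).to(xm.xla_device())
    info = json.dumps({
        "name": "p",       # composite name
        "id": "0",         # pattern instance id
        "pos": 0,          # position of this tensor in the boundary
        "is_input": True,  # outputs use False and may carry "attr"
        "attr": None,
    })
    # Becomes a stablehlo.custom_call @xla_mark_tensor with this backend_config.
    x = torch_xla._XLAC._xla_mark_tensor(x, info)
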
diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc
new file mode 100644
index 000000000000..bc38fc659978
--- /dev/null
+++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc
@@ -0,0 +1,465 @@
+#include "torch_xla/csrc/runtime/stablehlo_composite_helper.h"
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Support/LogicalResult.h"
+#include "single_include/nlohmann/json.hpp"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace torch_xla {
+namespace runtime {
+
+namespace {
+
+using nlohmann::json;
+
+static bool IsXlaMarkTensorOp(mlir::Operation* op) {
+  if (op == nullptr) {
+    return false;
+  }
+  if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
+    return false;
+  }
+  if (!llvm::isa<mlir::stablehlo::CustomCallOp>(op)) {
+    return false;
+  }
+  auto target_name =
+      op->getAttr("call_target_name").dyn_cast<mlir::StringAttr>();
+  if (target_name == nullptr || target_name.str() != "xla_mark_tensor") {
+    return false;
+  }
+  return true;
+}
+
+struct BoundaryMetadata {
+  std::string name;
+  std::string id;
+  int64_t pos;
+  bool is_input;
+  std::unordered_map<std::string, json> attrs;
+
+  auto boundary_key() const { return std::forward_as_tuple(name, id); }
+
+  auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); }
+
+  bool operator==(const BoundaryMetadata& other) const {
+    return uid() == other.uid();
+  }
+  bool operator<(const BoundaryMetadata& other) const {
+    return uid() < other.uid();
+  }
+
+  static std::unique_ptr<BoundaryMetadata> Parse(llvm::StringRef str) {
+    auto j = json::parse(str, /*cb=*/nullptr, /*allow_exceptions=*/false);
+    return Build(j);
+  }
+
+ private:
+  template <typename T>
+  static bool CopyJsonValue(const nlohmann::basic_json<>& j,
+                            llvm::StringRef key, json::value_t expected_type,
+                            T& to) {
+    auto kv = j.find(key);
+
+    if (kv == j.end()) {
+      return false;
+    }
+    if (kv.value().type() != expected_type) {
+      return false;
+    }
+    kv.value().get_to(to);
+    return true;
+  }
+
+  static std::unique_ptr<BoundaryMetadata> Build(
+      const nlohmann::basic_json<>& j) {
+    BoundaryMetadata metadata;
+
+    bool is_valid_metadata_json =
+        CopyJsonValue(j, "name", json::value_t::string, metadata.name) &&
+        CopyJsonValue(j, "id", json::value_t::string, metadata.id) &&
+        CopyJsonValue(j, "pos", json::value_t::number_unsigned, metadata.pos) &&
+        CopyJsonValue(j, "is_input", json::value_t::boolean, metadata.is_input);
+
+    if (!is_valid_metadata_json) {
+      return nullptr;
+    }
+
+    if (auto kv = j.find("attr"); kv != j.end() && kv.value().is_object()) {
+      auto& attrs_j = kv.value();
+      for (auto attr_j = attrs_j.begin(); attr_j != attrs_j.end(); ++attr_j) {
+        metadata.attrs.insert({attr_j.key(), attr_j.value()});
+      }
+    }
+    return std::make_unique<BoundaryMetadata>(std::move(metadata));
+  }
+};
+
+class BuildStableHLOCompositePass
+    : public mlir::OperationPass<mlir::ModuleOp> {
+ public:
+  explicit BuildStableHLOCompositePass()
+      : mlir::OperationPass<mlir::ModuleOp>::OperationPass(
+            mlir::TypeID::get<BuildStableHLOCompositePass>()) {}
+
+  ~BuildStableHLOCompositePass() override = default;
+
+  void runOnOperation() override {
+    mlir::ModuleOp module_op = getOperation();
+    llvm::SmallVector<mlir::func::FuncOp> func_ops(
+        module_op.getOps<mlir::func::FuncOp>());
+    for (mlir::func::FuncOp& func_op : func_ops) {
+      llvm::DenseMap<const mlir::Operation*, size_t> op_order_map =
+          BuildOpOrderMap(func_op);
+      for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
+        if (mlir::failed(
+                BuildStableHLOComposite(op.getOperation(), op_order_map))) {
+          op.emitError() << "failed to build composite.";
+          return signalPassFailure();
+        }
+      }
+    }
+  }
+
+  mlir::StringRef getName() const override {
+    return llvm::getTypeName<BuildStableHLOCompositePass>();
+  }
+
+  std::unique_ptr<mlir::Pass> clonePass() const override {
+    return std::make_unique<BuildStableHLOCompositePass>(*this);
+  }
+
+ private:
+  llvm::DenseMap<const mlir::Operation*, size_t> BuildOpOrderMap(
+      mlir::func::FuncOp func_op) const {
+    llvm::DenseMap<const mlir::Operation*, size_t> op_order_map;
+    for (const auto& op : llvm::enumerate(func_op.getOps())) {
+      op_order_map[&op.value()] = op.index();
+    }
+    return op_order_map;
+  }
+
+  mlir::FailureOr<std::unique_ptr<BoundaryMetadata>> GetBoundaryMetadata(
+      mlir::Operation* op) {
+    if (!IsXlaMarkTensorOp(op)) {
+      return mlir::FailureOr<std::unique_ptr<BoundaryMetadata>>(nullptr);
+    }
+    auto backend_config =
+        op->getAttr("backend_config").dyn_cast<mlir::StringAttr>();
+    if (backend_config == nullptr) {
+      return mlir::FailureOr<std::unique_ptr<BoundaryMetadata>>(nullptr);
+    }
+    std::unique_ptr<BoundaryMetadata> metadata =
+        BoundaryMetadata::Parse(backend_config);
+    if (metadata == nullptr) {
+      return op->emitError() << "invalid boundary metadata JSON.";
+    }
+    return metadata;
+  }
+
+  mlir::FailureOr<mlir::DictionaryAttr> BuildDictionaryAttrFromJsonMap(
+      mlir::OpBuilder& builder,
+      const std::unordered_map<std::string, json>& json_map) {
+    llvm::SmallVector<mlir::NamedAttribute> named_attrs;
+    for (auto& [key, j] : json_map) {
+      switch (j.type()) {
+        case json::value_t::number_integer:
+        case json::value_t::number_unsigned:
+          named_attrs.push_back(
+              {builder.getStringAttr(key),
+               builder.getI64IntegerAttr(j.template get<int64_t>())});
+          break;
+        case json::value_t::number_float:
+          named_attrs.push_back(
+              {builder.getStringAttr(key),
+               builder.getF32FloatAttr(j.template get<float>())});
+          break;
+        case json::value_t::boolean:
+          named_attrs.push_back({builder.getStringAttr(key),
+                                 builder.getBoolAttr(j.template get<bool>())});
+          break;
+        case json::value_t::string:
+          named_attrs.push_back(
+              {builder.getStringAttr(key),
+               builder.getStringAttr(j.template get<std::string>())});
+          break;
+        default:
+          return mlir::failure();
+      }
+    }
+    return builder.getDictionaryAttr(named_attrs);
+  }
+
+  mlir::LogicalResult BuildStableHLOComposite(
+      mlir::Operation* op,
+      const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
+    auto metadata_or = GetBoundaryMetadata(op);
+    if (mlir::failed(metadata_or)) {
+      return mlir::failure();
+    }
+
+    std::unique_ptr<BoundaryMetadata> metadata = std::move(*metadata_or);
+    if (metadata == nullptr || metadata->is_input) {
+      return mlir::success();
+    }
+
+    auto args_ops_or = GetBoundaryArgsAndOps(op, *metadata, op_order_map);
+    if (mlir::failed(args_ops_or)) {
+      return mlir::failure();
+    }
+
+    auto [args, impl_ops] = *args_ops_or;
+
+    mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc(
+        op, absl::StrCat(metadata->name, ".impl"), args, impl_ops);
+
+    mlir::FailureOr<mlir::Operation*> composite_op_or =
+        BuildStableHLOCompositeOp(op, impl_func, args, *metadata);
+    if (mlir::failed(composite_op_or)) {
+      return mlir::failure();
+    }
+    mlir::Operation* composite_op = *composite_op_or;
+
+    // Updates all users of this op's result(s) to use the result(s) of the
+    // impl func call.
+    for (size_t i = 0; i < op->getNumResults(); ++i) {
+      mlir::OpResult result = op->getResult(i);
+      result.replaceAllUsesWith(composite_op->getResult(i));
+    }
+
+    // The unused impl_ops will be eliminated with the canonicalizer.
+    return mlir::success();
+  }
+
+  mlir::FailureOr<std::pair<llvm::SmallVector<mlir::Value>,
+                            llvm::SmallVector<mlir::Operation*>>>
+  GetBoundaryArgsAndOps(
+      mlir::Operation* boundary_output_op, const BoundaryMetadata& metadata,
+      const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
+    llvm::SetVector<mlir::Operation*, llvm::SmallVector<mlir::Operation*>>
+        impl_ops_setvec;
+    llvm::SetVector<std::pair<mlir::Value, int64_t>,
+                    llvm::SmallVector<std::pair<mlir::Value, int64_t>>>
+        arg_pos_setvec;
+    llvm::SmallVector<mlir::Operation*> processing({boundary_output_op});
+
+    // Reverse graph traversal: from the boundary output op back to a boundary
+    // input op, a global function arg, or a stablehlo constant.
+    while (!processing.empty()) {
+      mlir::Operation* curr_op = processing.back();
+      processing.pop_back();
+      if (impl_ops_setvec.contains(curr_op)) {
+        continue;
+      }
+
+      auto curr_metadata_or = GetBoundaryMetadata(curr_op);
+      if (mlir::failed(curr_metadata_or)) {
+        return mlir::failure();
+      }
+      std::unique_ptr<BoundaryMetadata> curr_metadata =
+          std::move(*curr_metadata_or);
+      if (curr_metadata != nullptr) {
+        if (curr_metadata->is_input &&
+            curr_metadata->boundary_key() == metadata.boundary_key()) {
+          // Terminal condition: boundary input op.
+          arg_pos_setvec.insert(
+              {curr_op->getResult(0).dyn_cast<mlir::Value>(),
+               curr_metadata->pos});
+          continue;
+        }
+      }
+
+      impl_ops_setvec.insert(curr_op);
+      for (mlir::Value value : curr_op->getOperands()) {
+        mlir::Operation* def_op = value.getDefiningOp();
+        if (def_op == nullptr) {
+          // Terminal condition: global function arg.
+          arg_pos_setvec.insert({value, std::numeric_limits<int64_t>::max()});
+        } else if (llvm::isa<mlir::stablehlo::ConstantOp>(def_op)) {
+          // Terminal condition: constant.
+          impl_ops_setvec.insert(def_op);
+        } else {
+          processing.push_back(def_op);
+        }
+      }
+    }
+    // Sorts all ops within the boundary by their line numbers in the input
+    // MLIR. The ops will be duplicated to the impl function following this
+    // order.
+    llvm::SmallVector<mlir::Operation*> impl_ops =
+        impl_ops_setvec.takeVector();
+    for (auto& op : impl_ops) {
+      if (!op_order_map.contains(op)) {
+        return op->emitError()
+               << "does not have an ordering number in its outer func.";
+      }
+    }
+    std::sort(impl_ops.begin(), impl_ops.end(),
+              [&op_order_map](const auto& a, const auto& b) {
+                return op_order_map.at(a) < op_order_map.at(b);
+              });
+
+    // Sorts boundary args by their positions. Note that the args of the
+    // composite and impl function may be more than the boundary inputs,
+    // because the MLIR is lowered from the functionalized graph and
+    // additional args may be PyTorch constants. In such cases the position
+    // of those args is undetermined, but they always come after the boundary
+    // inputs.
+    auto arg_pos_pairs = arg_pos_setvec.takeVector();
+    std::stable_sort(
+        arg_pos_pairs.begin(), arg_pos_pairs.end(),
+        [](const auto& a, const auto& b) { return a.second < b.second; });
+    llvm::SmallVector<mlir::Value> args;
+    args.reserve(arg_pos_pairs.size());
+    for (auto& [arg, unused] : arg_pos_pairs) {
+      args.push_back(arg);
+    }
+
+    return std::make_pair(std::move(args), std::move(impl_ops));
+  }
+
+  mlir::func::FuncOp BuildStableHLOCompositeImplFunc(
+      mlir::Operation* boundary_output_op, llvm::StringRef func_name,
+      const llvm::SmallVector<mlir::Value>& args,
+      const llvm::SmallVector<mlir::Operation*>& impl_ops) {
+    mlir::ModuleOp module_op = getOperation();
+    mlir::MLIRContext* context = &getContext();
+    mlir::OpBuilder builder(context);
+
+    // Creates the composite impl function and duplicates all ops within the
+    // boundary into that function.
+    llvm::SmallVector<mlir::Location> arg_locs;
+    llvm::SmallVector<mlir::Type> arg_types;
+    for (auto& arg : args) {
+      arg_types.push_back(arg.getType());
+      arg_locs.push_back(arg.getLoc());
+    }
+    llvm::SmallVector<mlir::Type> result_types(
+        boundary_output_op->getResultTypes().begin(),
+        boundary_output_op->getResultTypes().end());
+
+    mlir::func::FuncOp impl_func = builder.create<mlir::func::FuncOp>(
+        module_op.getLoc(), func_name,
+        mlir::FunctionType::get(context, arg_types, result_types));
+    mlir::IRMapping mapping;
+    builder.createBlock(&impl_func.getBody(), impl_func.begin(), arg_types,
+                        arg_locs);
+    for (const auto& arg : llvm::enumerate(args)) {
+      mapping.map(arg.value(), impl_func.getArgument(arg.index()));
+    }
+    for (mlir::Operation* original_op : impl_ops) {
+      mlir::Operation* cloned_op = builder.clone(*original_op, mapping);
+      mapping.map(original_op, cloned_op);
+    }
+    builder.create<mlir::func::ReturnOp>(
+        impl_func.getBody().getLoc(),
+        mapping.lookup(boundary_output_op)->getResults());
+
+    // Adds the new function to the symbol table.
+    mlir::SymbolTable symbol_table(module_op);
+    impl_func.setPrivate();
+    symbol_table.insert(impl_func);
+
+    return impl_func;
+  }
+
+  mlir::FailureOr<mlir::Operation*> BuildStableHLOCompositeOp(
+      mlir::Operation* boundary_output_op, mlir::func::FuncOp impl_func,
+      const llvm::SmallVector<mlir::Value>& args,
+      const BoundaryMetadata& metadata) {
+    mlir::ModuleOp module_op = getOperation();
+    mlir::MLIRContext* context = &getContext();
+    mlir::OpBuilder builder(context);
+
+    mlir::FailureOr<mlir::DictionaryAttr> attributes_or =
+        BuildDictionaryAttrFromJsonMap(builder, metadata.attrs);
+    if (mlir::failed(attributes_or)) {
+      return boundary_output_op->emitError()
+             << "failed to transform boundary attr "
+                "JSON into composite attributes.";
+    }
+
+    builder.setInsertionPointAfter(boundary_output_op);
+    llvm::SmallVector<mlir::NamedAttribute> call_attrs{
+        {
+            builder.getStringAttr("call_target_name"),
+            builder.getStringAttr("stablehlo.composite"),
+        },
+        {
+            builder.getStringAttr("called_computations"),
+            builder.getArrayAttr(mlir::FlatSymbolRefAttr::get(
+                builder.getContext(), impl_func.getSymName())),
+        },
+        {
+            builder.getStringAttr("composite.backend_config"),
+            builder.getDictionaryAttr(llvm::SmallVector<mlir::NamedAttribute>{
+                {
+                    builder.getStringAttr("attributes"),
+                    *attributes_or,
+                },
+                {
+                    builder.getStringAttr("name"),
+                    builder.getStringAttr(metadata.name),
+                },
+            }),
+        },
+    };
+
+    // Creates and inserts the composite call op.
+    mlir::Operation* composite_op =
+        builder.create<mlir::stablehlo::CustomCallOp>(
+            boundary_output_op->getLoc(),
+            impl_func.getFunctionType().getResults(), args, call_attrs);
+    return composite_op;
+  }
+};
+
+class RemoveXlaMarkTensorOpsPass
+    : public mlir::OperationPass<mlir::func::FuncOp> {
+ public:
+  explicit RemoveXlaMarkTensorOpsPass()
+      : mlir::OperationPass<mlir::func::FuncOp>::OperationPass(
+            mlir::TypeID::get<RemoveXlaMarkTensorOpsPass>()) {}
+
+  ~RemoveXlaMarkTensorOpsPass() override = default;
+
+  void runOnOperation() override {
+    mlir::func::FuncOp func_op = getOperation();
+
+    for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
+      if (!IsXlaMarkTensorOp(op.getOperation())) {
+        continue;
+      }
+      mlir::Value original_value = op->getOperand(0);
+
+      for (mlir::Value result : op.getResults()) {
+        result.replaceAllUsesWith(original_value);
+      }
+    }
+
+    // The unused custom_call ops will be eliminated with the canonicalizer.
+  }
+
+  mlir::StringRef getName() const override {
+    return llvm::getTypeName<RemoveXlaMarkTensorOpsPass>();
+  }
+
+  std::unique_ptr<mlir::Pass> clonePass() const override {
+    return std::make_unique<RemoveXlaMarkTensorOpsPass>(*this);
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateBuildStableHLOCompositePass() {
+  return std::make_unique<BuildStableHLOCompositePass>();
+}
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateRemoveXlaMarkTensorOpsPass() {
+  return std::make_unique<RemoveXlaMarkTensorOpsPass>();
+}
+
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.h b/torch_xla/csrc/runtime/stablehlo_composite_helper.h
new file mode 100644
index 000000000000..e34f4cfcf877
--- /dev/null
+++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.h
@@ -0,0 +1,20 @@
+#ifndef STABLEHLO_COMPOSITE_HELPER_H_
+#define STABLEHLO_COMPOSITE_HELPER_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace torch_xla {
+namespace runtime {
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateBuildStableHLOCompositePass();
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateRemoveXlaMarkTensorOpsPass();
+
+}  // namespace runtime
+}  // namespace torch_xla
+
+#endif  // STABLEHLO_COMPOSITE_HELPER_H_
diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cc b/torch_xla/csrc/runtime/stablehlo_helper.cc
index 735e6cb474ff..ccf61984f770 100644
--- a/torch_xla/csrc/runtime/stablehlo_helper.cc
+++ b/torch_xla/csrc/runtime/stablehlo_helper.cc
@@ -10,6 +10,7 @@
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
 #include "stablehlo/dialect/VhloOps.h"       // from @stablehlo
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/stablehlo_composite_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
@@ -81,6 +82,12 @@ static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
   // Canonicalization after tuple flatten, to remove unused tuple op.
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
   pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
+  // Group patterns into StableHLO composites.
+  pm.addPass(torch_xla::runtime::CreateBuildStableHLOCompositePass());
+  pm.addNestedPass<mlir::func::FuncOp>(
+      torch_xla::runtime::CreateRemoveXlaMarkTensorOpsPass());
+  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
+  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
   if (!mlir::succeeded(pm.run(*mlir_module))) {
     return absl::Status(
         absl::StatusCode::kInternal,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index a2a18c45142c..705ddb6c07a2 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -64,6 +64,7 @@
 #include "torch_xla/csrc/ops/linspace.h"
 #include "torch_xla/csrc/ops/log_softmax.h"
 #include "torch_xla/csrc/ops/logsumexp.h"
+#include "torch_xla/csrc/ops/mark_tensor.h"
 #include "torch_xla/csrc/ops/masked_scatter.h"
 #include "torch_xla/csrc/ops/masked_select.h"
 #include "torch_xla/csrc/ops/max_in_dim.h"
@@ -1655,6 +1656,12 @@ XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other) {
   return DispatchComparisonOp(at::aten::lt, input, other);
 }
 
+XLATensorPtr mark_tensor(const XLATensorPtr& input, const std::string& info) {
+  torch::lazy::NodePtr node =
+      torch::lazy::MakeNode<MarkTensor>(input->GetIrValue(), info);
+  return input->CreateFrom(torch::lazy::Value(node));
+}
+
 XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
                             const XLATensorPtr& source) {
   torch::lazy::ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 7c51c64a9c14..c25c25e52011 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -523,6 +523,8 @@ XLATensorPtr lt(const XLATensorPtr& input, const at::Scalar& other);
 
 XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other);
 
+XLATensorPtr mark_tensor(const XLATensorPtr& input, const std::string& info);
+
 XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
                             const XLATensorPtr& source);
 
diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py
new file mode 100644
index 000000000000..b02a702f9ec4
--- /dev/null
+++ b/torch_xla/experimental/mark_pattern_utils.py
@@ -0,0 +1,81 @@
+import uuid
+from typing import Dict, Optional, Union
+
+import torch
+from torch_xla.experimental import xla_marker
+
+
+class StableHLOCompositeBuilder:
+  """
+  Helper for building a StableHLO Composite by marking input and output
+  tensors. It should be used with the StableHLO converters from
+  `torch_xla.stablehlo`.
+
+  Args:
+    name (str):
+      The name of the built StableHLO Composite op.
+    attr (dict):
+      Attributes of the StableHLO Composite op.
+  """
+
+  def __init__(self,
+               name: str,
+               attr: Optional[Dict[str, Union[int, float, str]]] = None):
+    self.attr = attr
+    self.name = name
+    self.id = uuid.uuid4().hex
+    self._inputs = []
+    self._outputs = []
+
+  def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
+    marked_tensors = []
+    for pos, tensor in enumerate(tensors):
+      if not isinstance(tensor, torch.Tensor):
+        raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
+      marked_tensors.append(
+          torch.ops.xla_pattern_marking.mark_tensor(
+              tensor,
+              name=self.name,
+              pos=pos,
+              id=self.id,
+              is_input=is_input,
+              attr=self.attr if not is_input else None,
+          ))
+
+    if len(marked_tensors) == 1:
+      return marked_tensors[0]
+    return tuple(marked_tensors)
+
+  def mark_inputs(self, *tensors: torch.Tensor):
+    """
+    Mark the input tensors of the StableHLO Composite. This method must only
+    be called once per builder.
+
+    Args:
+      *tensors (torch.Tensor):
+        Torch tensors to mark.
+    Returns:
+      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
+        Torch tensors marked as composite inputs. The tensor inputs of this
+        method should be replaced by the marked tensors in later usages.
+    """
+
+    return self._mark_tensor(*tensors, is_input=True)
+
+  def mark_outputs(self, *tensors: torch.Tensor):
+    """
+    Mark the output tensors of the StableHLO Composite. This method must only
+    be called once per builder.
+
+    Args:
+      *tensors (torch.Tensor):
+        Torch tensors to mark.
+    Returns:
+      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
+        Torch tensors marked as composite outputs. The tensor inputs of this
+        method should be replaced by the marked tensors in later usages.
+    """
+
+    if len(tensors) > 1:
+      # TODO: Allow multiple composite outputs
+      raise ValueError(
+          "StableHLO composites with more than one output are not supported.")
+    return self._mark_tensor(*tensors, is_input=False)
diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py
new file mode 100644
index 000000000000..f967fb9b787d
--- /dev/null
+++ b/torch_xla/experimental/xla_marker.py
@@ -0,0 +1,90 @@
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import torch
+import torch_xla
+from torch.library import Library, impl
+
+xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")
+
+xla_pattern_marking_lib.define(
+    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
+)
+
+
+@dataclass
+class BoundaryMetadata:
+  name: str  # Name of the pattern.
+  pos: int  # Arg/return position.
+  id: str  # Pattern instance id.
+  is_input: bool = True  # Whether the marked tensor is an input or an output.
+  attr: Optional[dict] = None  # Pattern attrs, expected to be attached to the output.
+
+
+class BoundaryMetadataSerializer(json.JSONEncoder):
+
+  def default(self, obj):
+    if dataclasses.is_dataclass(obj):
+      return dataclasses.asdict(obj)
+    return super().default(obj)
+
+
+def _assert_valid_composite_attr(attr):
+  if attr is None:
+    return
+  if not isinstance(attr, dict):
+    raise ValueError("Composite attr must be a Python dictionary.")
+
+  for k, v in attr.items():
+    if not isinstance(k, str):
+      raise ValueError("Composite attr name must be a Python str.")
+    if type(v) not in [str, float, int]:
+      raise ValueError(
+          "Composite attr value must be either Python str, float, or int.")
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor", "XLA")
+def mark_tensor_xla(x: torch.Tensor,
+                    name: str,
+                    pos: int,
+                    id: str,
+                    is_input: bool,
+                    attr: Optional[Dict] = None):
+  """Attach pattern boundary metadata to an XLA tensor.
+
+  Args:
+    x: torch.Tensor (on XLA device) - the marked tensor.
+    name: str - The name of the pattern; it becomes the name of the stablehlo
+      composite op.
+    pos: int - Input/output position of the annotated tensor in the pattern.
+    id: str - Unique identifier of the pattern instance.
+    is_input: bool - Whether the annotated tensor is an input to the pattern.
+    attr: dict - Attributes of the pattern; they are passed down to the
+      attributes field in the stablehlo composite.
+  """
+  _assert_valid_composite_attr(attr)
+  pattern_info = BoundaryMetadata(name, pos, id, is_input, attr)
+  return torch_xla._XLAC._xla_mark_tensor(
+      x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor", "CompositeExplicitAutograd")
+def mark_tensor(x: torch.Tensor,
+                name: str,
+                pos: int,
+                id: str,
+                is_input: bool,
+                attr: Optional[Dict] = None):
+  # Do nothing for non-XLA tensors.
+  return x
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor", "Meta")
+def mark_tensor_meta(x: torch.Tensor,
+                     name: str,
+                     pos: int,
+                     id: str,
+                     is_input: bool,
+                     attr: Optional[Dict] = None):
+  return torch.empty_like(x)
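
End-to-end, the flow this patch enables is the one exercised in test_mark_pattern.py: wrap a region between builder marks, export to StableHLO, and the marked region is outlined into a private `.impl` function invoked through a `stablehlo.custom_call` with target `stablehlo.composite`. A minimal sketch distilled from `test_composite_builder_sdpa_pattern` above (shapes and scale values are simply the ones the test uses):

    import torch
    import torch.nn.functional as F
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


    class SDPA(torch.nn.Module):

      def forward(self, x):
        # One builder per pattern instance; "scale" becomes a composite attribute.
        b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
        q, k, v = x.split(128, dim=-2)
        q, k, v = b.mark_inputs(q, k, v)
        out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        return b.mark_outputs(out)


    x = torch.randn((32, 8, 384, 64)).to(xm.xla_device())
    out = SDPA()(x)
    stablehlo = xm.get_stablehlo([out])
    # The marked region is outlined into @sdpa.impl and called via
    # a custom_call with call_target_name "stablehlo.composite".
    assert stablehlo.count("@stablehlo.composite") == 1
    assert '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}' in stablehlo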