From 0b311d0047e2eeff4b463d5bcd59ed341b7458b9 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:46:14 -0700 Subject: [PATCH] Update XLA pin to 2023/08/08 (#5427) * Update XLA pin to 2023/08/08 * Update pin one more time to pass a gs file system bug * manually add gcs to our dep * fix cpp test build --- WORKSPACE | 4 ++-- bazel/rules_def.bzl | 2 +- setup.py | 2 +- torch_xla/csrc/init_python_bindings.cpp | 9 ++++----- torch_xla/csrc/runtime/BUILD | 1 + torch_xla/csrc/runtime/pjrt_computation_client.cc | 6 ++---- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 72010c553a49..790f6d8f31cb 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,9 +43,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", ], - strip_prefix = "xla-9b339c6fa10f6e964e21b58e40217661f7824bae", + strip_prefix = "xla-cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d", urls = [ - "https://github.com/openxla/xla/archive/9b339c6fa10f6e964e21b58e40217661f7824bae.tar.gz", + "https://github.com/openxla/xla/archive/cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d.tar.gz", ], ) diff --git a/bazel/rules_def.bzl b/bazel/rules_def.bzl index 0d239ce71abe..08c7d237f65c 100644 --- a/bazel/rules_def.bzl +++ b/bazel/rules_def.bzl @@ -24,7 +24,7 @@ def ptxla_cc_test( **kwargs): xla_cc_test( linkstatic = True, - extra_copts = copts + [ + copts = copts + [ "-isystemexternal/torch", # Required for system includes. "-fexceptions", # Required for testing crashes. ], diff --git a/setup.py b/setup.py index 0a921968cf49..244b9489faf6 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20230703' +_libtpu_version = '0.1.dev20230809' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 419a281691c2..88bd88eac198 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1628,12 +1628,11 @@ void InitXlaModuleBindings(py::module m) { XLA_CHECK(num_nodes > 0) << "num_nodes must be positive: " << num_nodes; - xla::DistributedRuntimeServiceImpl::Options options; + xla::CoordinationServiceImpl::Options options; options.num_nodes = num_nodes; - return std::move(xla::GetDistributedRuntimeService( - dist_service_addr, options, - /*use_coordination_service=*/false) - .value()); + return std::move( + xla::GetDistributedRuntimeService(dist_service_addr, options) + .value()); }); BuildProfilerSubmodule(&m); diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 527298ae0cd6..60649b4f6b67 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -97,6 +97,7 @@ cc_library( "@xla//xla/pjrt:tfrt_cpu_pjrt_client", "@xla//xla/pjrt:pjrt_c_api_client", "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/platform/cloud:gcs_file_system", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index b82d0844a6c1..1d6671fbfeb2 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -29,7 +29,7 @@ #include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/pjrt/tpu_client.h" #include "xla/shape.h" -#include "xla/stream_executor/tpu/tpu_initializer_helper.h" +#include "xla/stream_executor/tpu/tpu_initializer_framework_helper.h" using xla::internal::XlaBuilderFriend; @@ -49,9 +49,7 @@ MaybeInitializeDistributedRuntimeClient(int local_rank, xla::DistributedRuntimeClient::Options options; /* TODO(jonbolin): Use global rank for multi-host setup */ options.node_id = local_rank; - client = - xla::GetDistributedRuntimeClient(dist_service_addr, options, - /*use_coordination_service=*/false); + client = xla::GetDistributedRuntimeClient(dist_service_addr, options); XLA_CHECK(client->Connect().ok()) << "Failed to initialize distributed runtime client"; }