Bazel torchxla (pytorch#4755)
* Migrate bazel torchxla

* Remove sandboxing for coverage execution

* Add missing files

* Remove build_torch_xla_libs mentions.

* Improve cache hits

* Don't separately require configuring test building; tests are free now.

* format python

* fix run_tests.sh

* Pass test arguments via bazelrc

* Merge tests into a single target due to grpc address in use issues.

* Make testenv consistent for cache hits

* Remove ABI logic, it's all in setup.py now

* Write both to log file and to output

* Update deprecated property

* add libpython to libs

* Change test filter flag

* Comment out log file for debugging

* Minimize downloads from cache

* Migrate to new bazel flag for exec properties

* Cache silo for CI

* set python version so that python3-config is found and used on circleci

* use ci cache silos when building

* simplify the silo flag

* improve silos

* Add conda init for tests

* format py

* hide the creds

* remove conda activation

* Setup conda library path

* Try improving conda setup

* Move the setup into bashrc

* common

* revert to old cache silo flag that allows overrides

* format py

* Revert to old style of specifying remote exec params

* Add bes timeout

* remove default silos key

* Rebase on updates

* pass in ld_lib_path to tests

* Propagate XLA_EXPERIMENTAL to bazel

* Support for cuda in tests

* Pass the cuda flag to cpp tests.

* remove cuda from deps of ptxla test since it's already in xla_client linked via xla_client:computation_client

* Fix multiconfiguration issues for tests

* Don't trim the test config; test_filter remains

* Copy the codegen directory to get the source in docker

* Add libtpu to the wheel, and link accordingly

* Include build extensions; that redefines some distutils classes.

* Update to cloud builder docker image and pass in the remote bazel flags

* Setup silo and remote cache for cloudbuild

* Set cache silo even with default creds

* fix debug flag

* Allow CXX_ABI flag to be set externally.

* Set instrumentation filter to avoid tests

* Document bazel

* User might be root often so make sure docs are clear

* format py

* Remove gen_lazy_tensor; now under codegen/

* Update documentation

* add coverage script

* Update docs with remote bazel role in gcp

* Update bazel docs

* Enable remote cache for bazel in ansible.

* Propagate default credentials to docker

* Remove unused rpath settings

* Upstream xla native functions

* Don't make the build DEBUG just for coverage.

* Avoid waiting for bes, which can be flaky

* Remove build-only testing

* Update xla native functions yaml

* Adjust cpp coverage stuff

* Use remote build for building tests.

* Debug mode

* Allow building tests

* Pass the TPU config to bazel tests.
stgpetrovic authored Apr 25, 2023
1 parent 74eff29 commit ed212d7
Showing 272 changed files with 2,245 additions and 902 deletions.
60 changes: 48 additions & 12 deletions .bazelrc
@@ -1,9 +1,6 @@
############################################################################
# All default build options below.

# Enable exceptions in C++.
common --copt=-fexceptions

# Make Bazel print out all options from rc files.
build --announce_rc

@@ -28,9 +25,9 @@ build -c opt
build --config=short_logs

# Force GCC because clang/bazel has issues.
common --action_env=CC=gcc
common --action_env=CXX=g++
common --spawn_strategy=standalone
build --action_env=CC=gcc
build --action_env=CXX=g++
build --spawn_strategy=standalone

###########################################################################

@@ -63,7 +60,6 @@ build:acl --define==build_with_acl=true
build:nonccl --define=no_nccl_support=true

build:linux --config=posix
build:linux --copt=-Wno-unknown-warning-option

# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
@@ -75,6 +71,45 @@ build:tpu --define=with_tpu_support=true
# RBE config options below.
# Flag to enable remote config
common --experimental_repo_remote_exec

# Inherit environmental variables that are used in testing.
test --test_env=TPU_NUM_DEVICES --test_env=GPU_NUM_DEVICES --test_env=CPU_NUM_DEVICES --test_env=XRT_LOCAL_WORKER
test --test_env=XRT_TPU_CONFIG --test_env=XRT_DEVICE_MAP --test_env=XRT_WORKERS --test_env=XRT_MESH_SERVICE_ADDRESS
test --test_env=XRT_SHARD_WORLD_SIZE --test_env=XRT_MULTI_PROCESSING_DEVICE --test_env=XRT_HOST_ORDINAL --test_env=XRT_SHARD_ORDINAL
test --test_env=XRT_START_LOCAL_SERVER --test_env=TPUVM_MODE --test_env=PJRT_DEVICE --test_env=PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS
test --test_env=PJRT_CPU_ASYNC_CLIENT --test_env=PJRT_GPU_ASYNC_CLIENT --test_env=TPU_LIBRARY_PATH --test_env=PJRT_DIST_SERVICE_ADDR
test --test_env=PJRT_LOCAL_PROCESS_RANK

# This environmental variable is important for properly integrating with XLA.
test --test_env=XLA_EXPERIMENTAL

# To find `libpython` that is required to run tests (they run using installed wheels).
test --test_env=LD_LIBRARY_PATH

# This fixes an issue where targets are configured differently because of `test_filter`.
# https://github.com/bazelbuild/bazel/issues/6842
test --notrim_test_configuration

# Stabilize the environmental variables used to minimize cache misses (src and env affects cache keys).
build --incompatible_strict_action_env

# By default in local builds, do not upload local results to cache.
build --noremote_upload_local_results

# Remote caching with local builds.
build:remote_cache --remote_cache=grpcs://remotebuildexecution.googleapis.com
build:remote_cache --remote_instance_name=projects/tpu-pytorch/instances/default_instance
build:remote_cache --google_default_credentials
build:remote_cache --remote_upload_local_results
build:remote_cache --bes_backend=buildeventservice.googleapis.com
build:remote_cache --bes_upload_mode=fully_async
build:remote_cache --bes_results_url="https://source.cloud.google.com/results/invocations"
build:remote_cache --bes_instance_name="tpu-pytorch"
build:remote_cache --bes_timeout=600s # On longer builds, BES can cause a non-zero exit from bazel.

# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:remote_cache --remote_download_toplevel
#########################################################################

# Load rc file with user-specific options.
@@ -84,17 +119,14 @@ try-import %workspace%/.bazelrc.user
build:compdb --features=-layering_check

# Compiling tests requires Java.
common --java_runtime_version=remotejdk_11
build --java_runtime_version=remotejdk_11

# Coverage requires Java and GCC.
coverage --config=coverage
coverage --build_tests_only
build:coverage --copt=-DNDEBUG
coverage --instrumentation_filter="//torch_xla/,//third_party/"
build:coverage --combined_report=lcov
build:coverage --strategy=TestRunner=sandboxed,local
build:coverage --strategy=CoverageReport=sandboxed,local
build:coverage --experimental_use_llvm_covmap
build:coverage --collect_code_coverage
build:coverage --test_tag_filters=-nocoverage

############################################################################
@@ -175,3 +207,7 @@ build:linux --copt="-Wswitch"
build:linux --copt="-Werror=switch"
# Required for building with clang
build:linux --copt="-Wno-error=unused-but-set-variable"

# Only include debug info for files in this repository, excluding external deps.
build:dbg -c dbg
build:dbg --per_file_copt=+external/.*@-g0,-DNDEBUG
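
Taken together, the new `.bazelrc` options are meant to be composed at the command line. A minimal sketch of typical invocations under these configs (target labels taken from elsewhere in this change; the installed-wheel prerequisite for tests is assumed):

```
# Plain local build: nothing is uploaded to the remote cache by default
# (--noremote_upload_local_results).
bazel build //:_XLAC.so

# Opt into the shared remote cache; requires Google Cloud credentials.
bazel build --config=remote_cache //:_XLAC.so

# Debug build: -g for first-party sources only, external deps stay at -g0.
bazel build --config=dbg //:_XLAC.so

# Run the C++ tests; PJRT_DEVICE and friends are forwarded from the
# invoking shell via the --test_env entries above, not hard-coded.
PJRT_DEVICE=CPU bazel test //test/cpp:all
```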
2 changes: 2 additions & 0 deletions .circleci/build.sh
@@ -46,6 +46,8 @@ python setup.py install
sccache --show-stats

source $XLA_DIR/xla_env
export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json"
export SILO_NAME='cache-silo-ci' # cache bucket for CI
build_torch_xla $XLA_DIR

popd
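
`build_torch_xla` consumes these two variables to authenticate against the remote cache and to pin a cache silo for CI. The exact wiring lives inside the build scripts; a hypothetical sketch of the kind of Bazel flags they plausibly translate into (flag values illustrative):

```
# Hypothetical sketch -- the real flag plumbing is inside build_torch_xla.
bazel build --config=remote_cache \
  --google_credentials="$GCLOUD_SERVICE_KEY_FILE" \
  --remote_default_exec_properties=cache-silo-key="$SILO_NAME" \
  //:_XLAC.so
```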
37 changes: 14 additions & 23 deletions .circleci/common.sh
@@ -14,6 +14,8 @@ fi
# 2. CONDA_PREFIX (if it exists)
# 3. The conda install directory (if it exists)
export CMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH:-${CONDA_PREFIX:-"$(dirname $(which conda))/../"}}
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$(python3-config --prefix)/lib"
echo $LD_LIBRARY_PATH

function clone_pytorch() {
PYTORCH_DIR=$1
@@ -95,21 +97,6 @@ function install_deps_pytorch_xla() {
if ls $CUBLAS_PATTERN 1> /dev/null 2>&1; then
sudo ln -s $CUBLAS_PATTERN /usr/local/cuda/include
fi

# Use cloud cache to build when available.
if [[ "$USE_CACHE" == 1 ]]; then
# Install bazels3cache for cloud cache
sudo npm install -g n
sudo n 16.18.0
sudo npm install -g bazels3cache
BAZELS3CACHE="$(which /usr/local/bin/bazels3cache)"
if [ -z "${BAZELS3CACHE}" ]; then
echo "Unable to find bazels3cache..."
return 1
fi
/usr/local/bin/bazels3cache --bucket=${XLA_CLANG_CACHE_S3_BUCKET_NAME} --maxEntrySizeBytes=0 --logging.level=verbose
sed -i '/bazel build/ a --remote_http_cache=http://localhost:7777 \\' $XLA_DIR/build_torch_xla_libs.sh
fi
}

function build_torch_xla() {
@@ -172,18 +159,22 @@ function run_torch_xla_tests() {

pushd test/cpp
echo "Running C++ Tests on PJRT"
EXTRA_ARGS=""
if [ "$USE_COVERAGE" != "0" ]; then
EXTRA_ARGS="-C"
fi
if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then
EXTRA_ARGS="-R"
fi
if [ -x "$(command -v nvidia-smi)" ]; then
PJRT_DEVICE=GPU ./run_tests.sh
PJRT_DEVICE=GPU ./run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L""
PJRT_DEVICE=GPU ./run_tests.sh $EXTRA_ARGS
PJRT_DEVICE=GPU ./run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
else
PJRT_DEVICE=CPU ./run_tests.sh
PJRT_DEVICE=CPU ./run_tests.sh $EXTRA_ARGS
fi
if [ "$USE_COVERAGE" != "0" ]; then
export PATH=$PATH:/usr/lib/llvm-8/bin
chmod +x /tmp/pytorch/xla/test/cpp/get_coverage.sh
lcov --directory /tmp/pytorch/xla/build/temp.linux-x86_64-cpython-38/torch_xla/csrc --base-directory . --gcov-tool /tmp/pytorch/xla/test/cpp/get_coverage.sh --capture -o cpp_lcov.info
genhtml cpp_lcov.info -o ~/htmlcov//cpp/cpp_lcov.info
mv cpp_lcov.info ~/htmlcov/cpp_lcov.info
genhtml .bazel-out/_coverage/_coverage_report.dat -o ~/htmlcov/cpp/cpp_lcov.info
mv ./.bazel-out/_coverage/_coverage_report.dat ~/htmlcov/cpp_lcov.info
fi
popd
popd
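
The coverage branch now reads Bazel's combined LCOV report instead of driving `lcov`/`gcov` by hand. A rough sketch of the flow it relies on (output directory illustrative):

```
# `coverage --config=coverage` in .bazelrc applies the config automatically;
# it enables --collect_code_coverage and merges per-test traces into one
# LCOV file via --combined_report=lcov.
bazel coverage //test/cpp:all

# Render the merged report to HTML.
genhtml bazel-out/_coverage/_coverage_report.dat -o ~/htmlcov
```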
5 changes: 3 additions & 2 deletions .circleci/config.yml
@@ -47,8 +47,9 @@ launch_docker_and_build: &launch_docker_and_build
echo "declare -x CIRCLE_PROJECT_USERNAME=${CIRCLE_PROJECT_USERNAME}" >> /home/circleci/project/env
echo "declare -x CIRCLE_PROJECT_REPONAME=${CIRCLE_PROJECT_REPONAME}" >> /home/circleci/project/env
# Set debug so that xla builds with coverage symbols
echo "declare -x DEBUG=1" >> /home/circleci/project/env
# Set up remote cache/build authentication.
echo "declare -x BAZEL_REMOTE_CACHE=1" >> /home/circleci/project/xla_env
(set +x; echo $GCLOUD_SERVICE_KEY > /home/circleci/project/default_credentials.json; set -x)
pid=$(docker run -t -d -w $WORKDIR ${GCR_DOCKER_IMAGE})
docker cp /home/circleci/project/. "$pid:$WORKDIR"
9 changes: 2 additions & 7 deletions .circleci/docker/Dockerfile
@@ -34,11 +34,6 @@ ENV CUDA_PATH /usr/local/cuda
ENV CC "${cc}"
ENV CXX "${cxx}"

# Whether to build torch and torch_xla libraries with CXX ABI
ENV _GLIBCXX_USE_CXX11_ABI "${cxx_abi}"
ENV CFLAGS "${CFLAGS} -D_GLIBCXX_USE_CXX11_ABI=${cxx_abi}"
ENV CXXFLAGS "${CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=${cxx_abi}"

# Whether to build for TPUVM mode
ENV TPUVM_MODE "${tpuvm}"

@@ -49,8 +44,8 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
# Install base system packages
RUN apt-get clean && apt-get update
RUN apt-get upgrade -y
RUN apt-get install --fix-missing -y python3-pip git curl libopenblas-dev vim jq \
apt-transport-https ca-certificates procps openssl sudo wget libssl-dev libc6-dbg
RUN apt-get install --fix-missing -y python-pip python3-pip git curl libopenblas-dev vim jq \
apt-transport-https ca-certificates procps openssl sudo wget libssl-dev libc6-dbg

# Install clang & llvm
ADD ./install_llvm_clang.sh install_llvm_clang.sh
2 changes: 2 additions & 0 deletions .circleci/test.sh
@@ -25,4 +25,6 @@ function install_torchvision() {

install_torchvision

export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json"
export SILO_NAME='cache-silo-ci' # cache bucket for CI
run_torch_xla_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE
6 changes: 3 additions & 3 deletions .gitignore
@@ -10,9 +10,6 @@ torch_xla/csrc/version.cpp
*.pyc
*.so

# Directory autogenerated by full_codegen
torch_xla/csrc/generated/

# BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
#
# Below files are not deleted by "setup.py clean".
@@ -30,3 +27,6 @@

# Build system temporary files
/bazel-*

# Clangd cache directory
.cache/*
8 changes: 5 additions & 3 deletions .vscode/settings.json
@@ -3,7 +3,9 @@
"--config=compdb",
],
"bsv.cc.compdb.targets": [
"//third_party/xla_client/...",
"//third_party/xla_client:all",
"//torch_xla/csrc:all",
"//test/cpp:all",
],
"coverage-gutters.coverageBaseDir": ".",
"coverage-gutters.showLineCoverage": false,
@@ -13,8 +15,8 @@
"./bazel-out/_coverage/_coverage_report.dat"
],
"lcov.path": [
"./.bazel-out/_coverage/_coverage_report.dat"
"./bazel-out/_coverage/_coverage_report.dat"
],
"python.formatting.provider": "yapf",
"editor.formatOnSave": true
}
}
57 changes: 57 additions & 0 deletions BUILD
@@ -0,0 +1,57 @@
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_cc_shared_object",
)

tf_cc_shared_object(
name = "_XLAC.so",
copts = [
"-DTORCH_API_INCLUDE_EXTENSION_H",
"-DTORCH_EXTENSION_NAME=_XLAC",
"-fopenmp",
"-fPIC",
"-fwrapv",
],
linkopts = [
"-Wl,-rpath,$$ORIGIN/torch_xla/lib", # for libtpu
],
visibility = ["//visibility:public"],
deps = [
"//third_party/xla_client:computation_client",
"//third_party/xla_client:mesh_service",
"//third_party/xla_client:metrics",
"//third_party/xla_client:metrics_analysis",
"//third_party/xla_client:metrics_reader",
"//third_party/xla_client:multi_wait",
"//third_party/xla_client:profiler",
"//third_party/xla_client:record_reader",
"//third_party/xla_client:sys_util",
"//third_party/xla_client:thread_pool",
"//third_party/xla_client:util",
"//third_party/xla_client:xla_util",
"//torch_xla/csrc:computation",
"//torch_xla/csrc:device",
"//torch_xla/csrc:init_python_bindings",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:version",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
"@org_tensorflow//tensorflow/compiler/xla/python/profiler/internal:traceme_wrapper",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_parser",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_verifier",
"@org_tensorflow//tensorflow/compiler/xla/service:sharding_propagation",
"@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner",
"@org_tensorflow//tensorflow/core",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core/platform:env",
"@org_tensorflow//tensorflow/core/profiler/lib:traceme",
"@org_tensorflow//tensorflow/python/profiler/internal:profiler_pywrap_impl",
"@torch//:headers",
"@torch//:libc10",
"@torch//:libtorch",
"@torch//:libtorch_cpu",
"@torch//:libtorch_python",
],
)
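
With this root `BUILD` target, Bazel rather than the old `build_torch_xla_libs.sh` flow produces the Python extension; building it directly is a one-liner (sketch):

```
bazel build //:_XLAC.so
```

The `$$ORIGIN/torch_xla/lib` rpath means the extension resolves `libtpu` relative to its own location inside the installed wheel, so no extra `LD_LIBRARY_PATH` entry is needed for it at import time.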
12 changes: 6 additions & 6 deletions CODEGEN_MIGRATION_GUIDE.md
@@ -30,14 +30,14 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e
- Contains all the op XLA supported today. Most of the ops are under the supported category, the goal of this document is to move most of the ops to the full_codegen category.
- xla/scripts/gen_lazy_tensor.py
- Provides necessary XLA versions of the codegen Codegen class and calls the upstream codegen API.
- xla/torch_xla/csrc/generated/XLANativeFunctions.cpp
- Result of the full_codegen column of the xla/xla_native_functions.yaml. The op function defined here will implement the op declared in the XLANativeFunctions.h. Each op will take at::tensor and return another at::tensor wrapped around a XLATensor.
- xla/torch_xla/csrc/generated/LazyIr.h
- Result of the full_codegen column of the xla/xla_native_functions.yaml. Defines the IR that is used to construct the full_codegen ops.
- xla/torch_xla/csrc/XLANativeFunctions.cpp
- Result of the full_codegen column of the xla/codegen/xla_native_functions.yaml. The op function defined here will implement the op declared in the XLANativeFunctions.h. Each op will take at::tensor and return another at::tensor wrapped around a XLATensor.
- xla/torch_xla/csrc/LazyIr.h
- Result of the full_codegen column of the xla/codegen/xla_native_functions.yaml. Defines the IR that is used to construct the full_codegen ops.

### PyTorch/XLA Old Op Lowering files
- xla/torch_xla/csrc/generated/aten_xla_type.cpp
- Manually implements ops defined in xla/xla_native_functions.yaml. Will be replaced by XLANativeFunctions.cpp
- Manually implements ops defined in xla/codegen/xla_native_functions.yaml. Will be replaced by XLANativeFunctions.cpp
- xla/torch_xla/csrc/generated/tensor.h
- Defines XLATensor class and XLATensor method declarations. These declarations are usually a one to one mapping of the at::Tensor nodes we declared in XLANativeFunctions.h. XLATensor method will be removed for full_codegen ops
- xla/torch_xla/csrc/generated/tensor_method.cpp
@@ -76,7 +76,7 @@ at::Tensor XLANativeFunctions::abs(const at::Tensor& self) {
```

### 2. Codegen the op and inspect the generated file
Find the op in `xla/xla_native_functions.yaml` and move it to the full_codegen column and run `python setup.py install` under xla directory again. The build will fail (reason explained later in this guide) but you can still see the generated file. The code snippets below uses `abs` as an example.
Find the op in `xla/codegen/xla_native_functions.yaml` and move it to the full_codegen column and run `python setup.py install` under xla directory again. The build will fail (reason explained later in this guide) but you can still see the generated file. The code snippets below uses `abs` as an example.
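The yaml edit itself is a one-line move between the two lists; a sketch using `abs` (neighboring entries elided):

```
# xla/codegen/xla_native_functions.yaml (sketch)
full_codegen:
  - abs            # moved here to be fully code-generated
supported:
  # - abs          # removed from the manually lowered list
```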
#### XLANativeFunctions.cpp
```
at::Tensor XLANativeFunctions::abs(const at::Tensor & self) {
2 changes: 1 addition & 1 deletion OP_LOWERING_GUIDE.md
@@ -14,7 +14,7 @@ export PJRT_DEVICE=CPU
You can find the definition of the C++ ATen operations in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). After you build Pytorch/XLA from source, you will also find our default implementation (a boxed kernel which forwards calls to PyTorch native CPU) in `xla/torch_xla/csrc/aten_cpu_fallback.h/cpp`. Pytorch operations can usually be mapped to [PyTorch tensor api](https://pytorch.org/docs/stable/index.html) easily. If that is not the case searching the PyTorch native implementation under [PyTorch repo](https://github.com/pytorch/pytorch) is recommended. The goal is to lower the PyTorch operations into a sequence of XLA operations defined in [here](https://www.tensorflow.org/xla/operation_semantics).

## File structure
All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the exception of `xla_native_functions.yaml`
All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the exception of `codegen/xla_native_functions.yaml`

1. `xla_native_functions.yaml` contains the list of all operators that are lowered. Each operator name must directly match a pytorch operator listed in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). This file serves as the interface to adding new xla operators, and is an input to PyTorch's [codegen machinery](https://github.com/pytorch/pytorch/blob/main/torchgen/gen_backend_stubs.py). It generates the below 3 files: `XLANativeFunctions.h`, `RegisterXLA.cpp`, and `RegisterAutogradXLA.cpp`
2. `XLANativeFunctions.h` and `aten_xla_type.cpp` are entry points of PyTorch to the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. `XLANativeFunctions.h` is auto-generated through a combination of `xla_native_functions.yaml` and the PyTorch core `native_functions.yaml` file, and contains declarations for kernels that need to be defined in `aten_xla_type.cpp`. The kernels written here need to construct 'XLATensor' using the input `at::Tensor` and other parameters. The resulting `XLATensor` needs to be converted back to the `at::Tensor` before returning to the PyTorch world.
35 changes: 35 additions & 0 deletions WORKSPACE
@@ -1,5 +1,30 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

################################ Python Setup ################################

# For embedded python interpreter (libpython.so.)
http_archive(
name = "pybind11_bazel",
strip_prefix = "pybind11_bazel-fc56ce8a8b51e3dd941139d329b63ccfea1d304b",
urls = ["https://github.com/pybind/pybind11_bazel/archive/fc56ce8a8b51e3dd941139d329b63ccfea1d304b.zip"],
)

http_archive(
name = "pybind11",
build_file = "@pybind11_bazel//:pybind11.BUILD",
strip_prefix = "pybind11-442261da585536521ff459b1457b2904895f23b4",
urls = ["https://github.com/pybind/pybind11/archive/442261da585536521ff459b1457b2904895f23b4.tar.gz"],
)

load("@pybind11_bazel//:python_configure.bzl", "python_configure")

# This is required for setting up the linkopts for -lpython.
python_configure(
name = "local_config_python",
python_version = "3", # required to use `python3-config`
)
############################# TensorFlow Setup ###############################

# To update TensorFlow to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
@@ -60,3 +85,13 @@ tf_workspace1()
load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")

tf_workspace0()

################################ PyTorch Setup ################################

load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR")

new_local_repository(
name = "torch",
build_file = "//bazel:torch.BUILD",
path = PYTORCH_LOCAL_DIR,
)
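
`new_local_repository` maps an existing PyTorch checkout into the workspace as `@torch`, so `//bazel:dependencies.bzl` only has to export the checkout path. A hypothetical minimal version (the real file may compute the path instead of hard-coding it):

```
# bazel/dependencies.bzl -- hypothetical sketch.
PYTORCH_LOCAL_DIR = "../.."  # relative path to the PyTorch source tree
```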