Skip to content

Commit b359ea7

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Wire up the plugin that contains the MGPU custom call into Bazel tests
PiperOrigin-RevId: 751445068
1 parent c0389ff commit b359ea7

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

jax_plugins/cuda/BUILD.bazel

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
licenses(["notice"])
16-
1715
load(
18-
"//jaxlib:jax.bzl",
19-
"if_windows",
20-
"py_library_providing_imports_info",
21-
"pytype_library",
16+
"//jaxlib:jax.bzl",
17+
"if_windows",
18+
"py_library_providing_imports_info",
19+
"pytype_library",
2220
)
2321

22+
licenses(["notice"])
23+
2424
package(
2525
default_applicable_licenses = [],
2626
default_visibility = ["//:__subpackages__"],
@@ -41,7 +41,7 @@ py_library_providing_imports_info(
4141
],
4242
data = if_windows(
4343
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
44-
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
44+
["//jaxlib/tools:pjrt_c_api_gpu_plugin.so"],
4545
),
4646
lib_rule = pytype_library,
4747
)

jax_plugins/cuda/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _get_library_path():
5151
runfiles_dir = os.getenv('RUNFILES_DIR', None)
5252
if runfiles_dir:
5353
local_path = os.path.join(
54-
runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
54+
runfiles_dir, '__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so'
5555
)
5656

5757
if os.path.exists(local_path):

0 commit comments

Comments
 (0)