This repository has been archived by the owner on Oct 21, 2024. It is now read-only.
I created a fresh conda environment and `pip install ...`'ed the `requirements.txt`, only to realize this had not installed a GPU-compatible jax. After a little searching, I installed jax using `conda install jax cuda-nvcc -c conda-forge -c nvidia`, as recommended by this page, but then I got the following warnings telling me that I was not using the GPU, which I'd like to use:
```
(jax_verify) chelseas@server:~/jax_verify$ python3 examples/run_boundprop.py --boundprop_method=backward_crown_bound_propagation
I1228 22:37:51.228809 140315511407680 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I1228 22:37:51.228915 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I1228 22:37:51.228991 140315511407680 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229050 140315511407680 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229235 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I1228 22:37:51.229310 140315511407680 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W1228 22:37:51.229365 140315511407680 xla_bridge.py:362] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
```
I initially opened an issue in the jax repo, but I thought that someone maintaining this repo might also be able to help.
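To check which backend JAX ends up using, independent of jax_verify, you can ask JAX directly. The sketch below is a minimal example built on `jax.default_backend()` and `jax.devices()`; the helper name `report_jax_backend` is hypothetical, and it simply returns `None` if jax cannot be imported in the active environment. On a working GPU install the backend should be reported as `gpu`, while the fallback described in the log above would show `cpu`.

```python
import importlib


def report_jax_backend():
    """Return (backend, devices) as seen by JAX, or None if JAX is not importable.

    Hypothetical helper for illustration; it is not part of jax or jax_verify.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        # JAX is not installed in this interpreter's environment.
        return None
    # default_backend() returns a platform name such as "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    # devices() lists the devices available on the default backend.
    devices = [str(d) for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    result = report_jax_backend()
    if result is None:
        print("jax is not importable from this environment")
    else:
        backend, devices = result
        print("default backend:", backend)
        print("devices:", devices)
```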
This seems to be a jax issue more than a jax_verify one.
Some things worth looking at:
- Is CUDA installed correctly on your machine? Do you have the same problem with other frameworks (like PyTorch or TensorFlow)? Can you find the /usr/local/cuda-11.4 directory on your machine?
- From your other issue, it seems that you are running CUDA 11.4. Have you made sure that you installed the version of jaxlib that matches it?
- Did you check that you activated the right conda environment?
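Several of the checks above can be gathered in one place. This is a hedged sketch, not part of jax_verify: the helper `environment_report` is hypothetical, it uses only the standard library, and the default `cuda_dir` is simply the `/usr/local/cuda-11.4` path mentioned above (adjust it for your machine). It reports which Python interpreter is running, the `CONDA_PREFIX` that conda sets when an environment is activated, whether the CUDA directory exists, whether `nvcc` is on `PATH`, and the installed jax and jaxlib versions, which you can then compare against the JAX installation instructions for your CUDA version.

```python
import importlib.metadata
import os
import shutil
import sys


def environment_report(cuda_dir="/usr/local/cuda-11.4"):
    """Collect a few facts relevant to the checklist above.

    Hypothetical helper; cuda_dir defaults to the path mentioned in the thread.
    """
    report = {
        # Which interpreter is running; it should live inside the intended conda env.
        "python": sys.executable,
        # conda sets CONDA_PREFIX when an environment is activated (None otherwise).
        "conda_prefix": os.environ.get("CONDA_PREFIX"),
        # Whether the CUDA toolkit directory exists at the given path.
        "cuda_dir_exists": os.path.isdir(cuda_dir),
        # Full path to nvcc if it is on PATH, else None.
        "nvcc": shutil.which("nvcc"),
    }
    # Installed distribution versions, or None if the package is missing.
    for pkg in ("jax", "jaxlib"):
        try:
            report[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            report[pkg] = None
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If `python` does not point inside the environment you expect, or `jaxlib` comes back as `None`, the environment and jaxlib questions above are the first things to revisit.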