Open
Description
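Hello Dr. Zhong,
Thank you for developing Parafold, it is a great tool!
I did not run into any problems while installing Parafold, and the first step (feature) completed successfully. When I tried to run the second step (structure prediction), I hit a JAX-related error. I would appreciate any suggestions you can give.
My GPU configuration is as follows: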
NVIDIA-SMI 550.90.12 Driver Version: 550.90.12 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H800 PCIe Off | 00000000:34:00.0 Off | 0 |
| N/A 52C P0 89W / 350W | 1MiB / 81559MiB | 3% Default |
| | | Disabled
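I installed Parafold following the "How to install" section of the README, and the JAX version is 0.3.25, as specified there.
I also installed cuda-nvcc as suggested in issue #39, but that did not resolve the problem.
The error output is as follows: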
2024-11-21 11:37:59.402301: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-11-21 11:37:59.402323: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2024-11-21 11:37:59.404084: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:628] failed to get PTX kernel "shift_right_logical" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2024-11-21 11:37:59.404116: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
File "/home/software/ParallelFold/run_alphafold.py", line 491, in <module>
app.run(main)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/home/software/ParallelFold/run_alphafold.py", line 464, in main
predict_structure(
File "/home/software/ParallelFold/run_alphafold.py", line 239, in predict_structure
prediction_result = model_runner.predict(processed_feature_dict,
File "/home/software/ParallelFold/alphafold/model/model.py", line 167, in predict
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/random.py", line 132, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 580, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 592, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 597, in random_seed_impl_base
return seed(seeds)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 832, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 515, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
return compiled_fun(*args)
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
return lambda *args, **kw: compiled(*args, **kw)[0]
File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Could not find the corresponding function
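The first warning in the log above says that the `ptxas` XLA picked up does not support compute capability (CC) 9.0. Support for Hopper (`sm_90`) was added in CUDA 11.8, so one thing worth checking is which release of `ptxas` is on `PATH`. The following is a minimal sketch of that check, assuming `ptxas --version` prints its usual `release X.Y` line. The helper names (`parse_ptxas_release`, `ptxas_supports_cc90`, `check_local_ptxas`) are hypothetical and not part of Parafold or JAX. It only reports the `ptxas` release and does not establish that this is the sole cause of the error.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumption stated in the text above: sm_90 (Hopper) needs CUDA 11.8 or newer.
HOPPER_MIN_CUDA = (11, 8)


def parse_ptxas_release(version_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a line such as '... release 11.8, V11.8.89'."""
    match = re.search(r"release\s+(\d+)\.(\d+)", version_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_supports_cc90(release: Tuple[int, int]) -> bool:
    """Return True if this CUDA release is new enough to target sm_90."""
    return release >= HOPPER_MIN_CUDA


def check_local_ptxas() -> str:
    """Describe the ptxas found on PATH (hypothetical helper for diagnosis)."""
    ptxas = shutil.which("ptxas")
    if ptxas is None:
        return "ptxas not found on PATH"
    try:
        proc = subprocess.run(
            [ptxas, "--version"], capture_output=True, text=True, check=False
        )
    except OSError as exc:
        return f"could not run {ptxas}: {exc}"
    release = parse_ptxas_release(proc.stdout)
    if release is None:
        return f"could not parse a release number from {ptxas}"
    verdict = "supports" if ptxas_supports_cc90(release) else "is too old for"
    return f"{ptxas}: CUDA {release[0]}.{release[1]} {verdict} CC 9.0"


if __name__ == "__main__":
    print(check_local_ptxas())
```

Running this inside the activated conda environment shows which `ptxas` would be found first on `PATH` there, which may differ from the system-wide CUDA 12.4 that nvidia-smi reports for the driver.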
This looks like an incompatibility between the H800 and JAX 0.3.25. Would upgrading JAX be an option?
Many thanks!