
jaxlib.xla_extension.XlaRuntimeError #50

@Xiaojun928

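Hi Dr. Zhong,

Thanks for developing Parafold, it's a great tool!
Installing Parafold has gone smoothly so far, and I successfully completed the first step (feature generation). When I tried to run the second step, structure prediction, I ran into a JAX-related problem. I'd really appreciate any suggestions.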

GPU configuration:

| NVIDIA-SMI 550.90.12              Driver Version: 550.90.12      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H800 PCIe               Off |   00000000:34:00.0 Off |                    0 |
| N/A   52C    P0             89W /  350W |       1MiB /  81559MiB |      3%      Default |
|                                         |                        |             Disabled |

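I installed it following the "How to install" section of the README, and the JAX version is 0.3.25, as the README specifies.
I also installed cuda-nvcc following the suggestion in issue #39, but the same error persists.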

Here is the error I get:

2024-11-21 11:37:59.402301: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-11-21 11:37:59.402323: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2024-11-21 11:37:59.404084: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:628] failed to get PTX kernel "shift_right_logical" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2024-11-21 11:37:59.404116: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
  File "/home/software/ParallelFold/run_alphafold.py", line 491, in <module>
    app.run(main)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/software/ParallelFold/run_alphafold.py", line 464, in main
    predict_structure(
  File "/home/software/ParallelFold/run_alphafold.py", line 239, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/software/ParallelFold/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/random.py", line 132, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 580, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 592, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 597, in random_seed_impl_base
    return seed(seeds)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 832, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 515, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
    return compiled_fun(*args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Could not find the corresponding function
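The first warning in the log says the `ptxas` in use does not support CC 9.0 (the H800's compute capability), after which XLA falls back to the CUDA driver and then fails to find the `shift_right_logical` kernel. Below is a minimal sketch for checking which `ptxas` is on `PATH` and whether its CUDA release can target sm_90. It assumes CUDA 11.8 is the first toolkit release with sm_90 (Hopper) support, and the helper names (`parse_ptxas_release`, `ptxas_supports_sm90`, `check_ptxas_on_path`) are hypothetical, not part of Parafold or JAX.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumption: CUDA 11.8 is the first toolkit whose ptxas can target sm_90 (Hopper, e.g. H100/H800).
MIN_CUDA_FOR_SM90 = (11, 8)


def parse_ptxas_release(version_text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from `ptxas --version` output,
    e.g. 'Cuda compilation tools, release 11.4, V11.4.152' -> (11, 4)."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_supports_sm90(version_text: str) -> bool:
    """True if the parsed CUDA release is at least MIN_CUDA_FOR_SM90 (an assumed threshold)."""
    release = parse_ptxas_release(version_text)
    return release is not None and release >= MIN_CUDA_FOR_SM90


def check_ptxas_on_path() -> None:
    """Report the first `ptxas` on PATH and whether it appears able to target sm_90."""
    ptxas = shutil.which("ptxas")
    if ptxas is None:
        print("ptxas not found on PATH")
        return
    out = subprocess.run([ptxas, "--version"], capture_output=True, text=True).stdout
    release = parse_ptxas_release(out)
    print(f"{ptxas}: CUDA release {release}, supports sm_90: {ptxas_supports_sm90(out)}")


if __name__ == "__main__":
    check_ptxas_on_path()
```

Running it inside the `parafold` conda environment shows which `ptxas` is found first on `PATH` (the log line `Used ptxas at ptxas` suggests XLA resolved it by bare name), which can help confirm whether the cuda-nvcc installed per issue #39 is the one actually being picked up.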

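This looks like an incompatibility between the H800 (compute capability 9.0, which the first warning above says `ptxas` does not support) and JAX 0.3.25. Would upgrading JAX help?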

Thanks a lot!
