
Jax module errors while following quick start guide #5

@mwishoff


OS: Ubuntu-22.04.3
Python: 3.10.12

I followed the quick start guide and installed AlphaStar into a Python virtual environment by running 'pip install .' from the repository root (the directory containing alphastar and setup.py). I then went to the unplugged quickstart, tried to train with dummy data, and got the following JAX error:

Traceback (most recent call last):
  File "/home/matt/dev/alphastar/alphastar/unplugged/scripts/train.py", line 56, in <module>
    from acme.jax import utils
  File "/home/matt/dev/rocket/lib/python3.10/site-packages/acme/jax/utils.py", line 190, in <module>
    devices: Optional[Sequence[jax.xla.Device]] = None,
  File "/home/matt/dev/rocket/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'

JAX versions:
jax 0.4.14
jaxlib 0.4.14
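A small script along these lines can confirm which versions are actually active inside the virtual environment, and whether the attribute that acme/jax/utils.py dereferences at import time (jax.xla.Device) is still exposed by the installed jax. It is only a sketch: it assumes Acme was installed under its PyPI distribution name dm-acme, and the helper names installed_version and has_attr_path are hypothetical, introduced here for illustration.

```python
import importlib
import importlib.metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def has_attr_path(module_name: str, attr_path: str) -> bool:
    """Return True if `module_name` imports and exposes the dotted `attr_path`."""
    try:
        obj = importlib.import_module(module_name)
    except ImportError:
        return False
    for part in attr_path.split("."):
        # hasattr() returns False when a module-level __getattr__ raises
        # AttributeError, which is how jax reports the missing 'xla' name above.
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


if __name__ == "__main__":
    # Assumption: Acme is installed under its PyPI distribution name "dm-acme".
    for dist in ("jax", "jaxlib", "dm-acme"):
        print(f"{dist}: {installed_version(dist)}")
    # acme/jax/utils.py evaluates jax.xla.Device in a function signature at import time.
    print("jax.xla.Device available:", has_attr_path("jax", "xla.Device"))
    print("jax.Device available:", has_attr_path("jax", "Device"))
```

Given the traceback above, the jax.xla.Device check would be expected to report False in this environment; running the script after each version change shows whether a given jax/jaxlib pin actually took effect.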

Following another user's post, I tried pinning both jax and jaxlib to 0.3.2, but that did not work.
I also tried downgrading Python to 3.9, since that is the recommended version, but that did not help either.

Any help on what I can try would be much appreciated.
