Skip to content

Commit 1ed84ae

Browse files
authored
Update references of jax_layout.Layout to jax_layout.Format. (#124)
This reflects an upstream change in JAX.
1 parent ed273b1 commit 1ed84ae

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
def _get_partition_spec(
3737
layout: (
3838
keras.distribution.TensorLayout
39-
| jax_layout.Layout
39+
| jax_layout.Format
4040
| jax.sharding.NamedSharding
4141
| jax.sharding.PartitionSpec
4242
),
@@ -45,7 +45,7 @@ def _get_partition_spec(
4545
if isinstance(layout, keras.distribution.TensorLayout):
4646
layout = layout.backend_layout
4747

48-
if isinstance(layout, jax_layout.Layout):
48+
if isinstance(layout, jax_layout.Format):
4949
layout = layout.sharding
5050

5151
if isinstance(layout, jax.sharding.NamedSharding):
@@ -217,7 +217,7 @@ def _create_sparsecore_distribution(
217217
sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
218218
# Custom sparsecore layout with tiling.
219219
# pylint: disable-next=protected-access
220-
sparsecore_layout._backend_layout = jax_layout.Layout(
220+
sparsecore_layout._backend_layout = jax_layout.Format(
221221
jax_layout.DeviceLocalLayout(
222222
major_to_minor=(0, 1),
223223
_tiling=((8,),),

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _create_sparsecore_layout(
4141
)
4242
sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
4343
# Custom sparsecore layout with tiling.
44-
sparsecore_layout._backend_layout = jax_layout.Layout( # pylint: disable=protected-access
44+
sparsecore_layout._backend_layout = jax_layout.Format( # pylint: disable=protected-access
4545
jax_layout.DeviceLocalLayout(
4646
major_to_minor=(0, 1),
4747
_tiling=((8,),),

keras_rs/src/layers/embedding/jax/embedding_lookup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import jax
1010
import numpy as np
11-
from jax.experimental import layout
11+
from jax.experimental import layout as jax_layout
1212
from jax_tpu_embedding.sparsecore.lib.nn import embedding
1313
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
1414
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
@@ -20,7 +20,7 @@
2020
shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
2121

2222
ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
23-
JaxLayout: TypeAlias = jax.sharding.NamedSharding | layout.Layout
23+
JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
2424

2525

2626
class EmbeddingLookupConfiguration:

requirements-jax-cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ torch>=2.1.0
88
# Jax with cuda support.
99
# Keep same version as Keras repo.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11-
jax[cuda12_pip]==0.6.0
11+
jax[cuda12_pip]==0.6.2
1212

1313
# Support for large embeddings.
1414
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'

0 commit comments

Comments
 (0)