Skip to content

Commit a94d9fc

Browse files
authored
Fix for removal of JAX DeviceLocalLayout. (#127)
Type verification fails after removal.
1 parent 49d32e1 commit a94d9fc

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _create_sparsecore_distribution(
219219
LayoutClass = (
220220
jax_layout.Layout
221221
if jax.__version_info__ >= (0, 6, 3)
222-
else jax_layout.DeviceLocalLayout
222+
else jax_layout.DeviceLocalLayout # type: ignore
223223
)
224224
# pylint: disable-next=protected-access
225225
sparsecore_layout._backend_layout = jax_layout.Format(

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _create_sparsecore_layout(
4444
LayoutClass = (
4545
jax_layout.Layout
4646
if jax.__version_info__ >= (0, 6, 3)
47-
else jax_layout.DeviceLocalLayout
47+
else jax_layout.DeviceLocalLayout # type: ignore
4848
)
4949
# pylint: disable-next=protected-access
5050
sparsecore_layout._backend_layout = jax_layout.Format(

0 commit comments

Comments
 (0)