File tree Expand file tree Collapse file tree 4 files changed +7
-7
lines changed
keras_rs/src/layers/embedding/jax Expand file tree Collapse file tree 4 files changed +7
-7
lines changed Original file line number Diff line number Diff line change 36
36
def _get_partition_spec (
37
37
layout : (
38
38
keras .distribution .TensorLayout
39
- | jax_layout .Layout
39
+ | jax_layout .Format
40
40
| jax .sharding .NamedSharding
41
41
| jax .sharding .PartitionSpec
42
42
),
@@ -45,7 +45,7 @@ def _get_partition_spec(
45
45
if isinstance (layout , keras .distribution .TensorLayout ):
46
46
layout = layout .backend_layout
47
47
48
- if isinstance (layout , jax_layout .Layout ):
48
+ if isinstance (layout , jax_layout .Format ):
49
49
layout = layout .sharding
50
50
51
51
if isinstance (layout , jax .sharding .NamedSharding ):
@@ -217,7 +217,7 @@ def _create_sparsecore_distribution(
217
217
sparsecore_layout = keras .distribution .TensorLayout (axes , device_mesh )
218
218
# Custom sparsecore layout with tiling.
219
219
# pylint: disable-next=protected-access
220
- sparsecore_layout ._backend_layout = jax_layout .Layout (
220
+ sparsecore_layout ._backend_layout = jax_layout .Format (
221
221
jax_layout .DeviceLocalLayout (
222
222
major_to_minor = (0 , 1 ),
223
223
_tiling = ((8 ,),),
Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ def _create_sparsecore_layout(
41
41
)
42
42
sparsecore_layout = keras .distribution .TensorLayout (axes , device_mesh )
43
43
# Custom sparsecore layout with tiling.
44
- sparsecore_layout ._backend_layout = jax_layout .Layout ( # pylint: disable=protected-access
44
+ sparsecore_layout ._backend_layout = jax_layout .Format ( # pylint: disable=protected-access
45
45
jax_layout .DeviceLocalLayout (
46
46
major_to_minor = (0 , 1 ),
47
47
_tiling = ((8 ,),),
Original file line number Diff line number Diff line change 8
8
9
9
import jax
10
10
import numpy as np
11
- from jax .experimental import layout
11
+ from jax .experimental import layout as jax_layout
12
12
from jax_tpu_embedding .sparsecore .lib .nn import embedding
13
13
from jax_tpu_embedding .sparsecore .lib .nn import embedding_spec
14
14
from jax_tpu_embedding .sparsecore .utils import utils as jte_utils
20
20
shard_map = jax .experimental .shard_map .shard_map # type: ignore[attr-defined]
21
21
22
22
ArrayLike : TypeAlias = jax .Array | np .ndarray [Any , Any ]
23
- JaxLayout : TypeAlias = jax .sharding .NamedSharding | layout . Layout
23
+ JaxLayout : TypeAlias = jax .sharding .NamedSharding | jax_layout . Format
24
24
25
25
26
26
class EmbeddingLookupConfiguration :
Original file line number Diff line number Diff line change @@ -8,7 +8,7 @@ torch>=2.1.0
8
8
# Jax with cuda support.
9
9
# Keep same version as Keras repo.
10
10
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11
- jax[cuda12_pip]==0.6.0
11
+ jax[cuda12_pip]==0.6.2
12
12
13
13
# Support for large embeddings.
14
14
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'
You can’t perform that action at this time.
0 commit comments