Skip to content

Commit 89c4651

Browse files
mridul-sahucopybara-github
authored andcommitted
Make CloudPathwaysArrayHandler compatible with async directory creation feature in orbax.
PiperOrigin-RevId: 751089745
1 parent c636e39 commit 89c4651

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

pathwaysutils/persistence/orbax_handler.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,18 @@ def __init__(
6363
raise ValueError("OCDBT not supported for Pathways.")
6464
super().__init__()
6565

66+
async def _background_serialize(
67+
self,
68+
values: Sequence[jax.Array],
69+
locations: Sequence[str],
70+
names: Sequence[str],
71+
) -> None:
72+
"""Uses Pathways Persistence API to serialize a jax array."""
73+
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
74+
futures_results = list(map(f, locations, names, values))
75+
for future_result in futures_results:
76+
future_result.result()
77+
6678
async def serialize(
6779
self,
6880
values: Sequence[jax.Array],
@@ -76,8 +88,12 @@ async def serialize(
7688
raise ValueError("Casting during save not supported for Pathways.")
7789

7890
locations, names = extract_parent_dir_and_name(infos)
79-
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
80-
return list(map(f, locations, names, values))
91+
return [
92+
future.CommitFutureAwaitingContractedSignals(
93+
self._background_serialize(values, locations, names),
94+
name="cloud_pathways_array_handler",
95+
)
96+
]
8197

8298
async def deserialize(
8399
self,

0 commit comments

Comments
 (0)