@@ -63,6 +63,18 @@ def __init__(
63
63
raise ValueError ("OCDBT not supported for Pathways." )
64
64
super ().__init__ ()
65
65
66
+ async def _background_serialize (
67
+ self ,
68
+ values : Sequence [jax .Array ],
69
+ locations : Sequence [str ],
70
+ names : Sequence [str ],
71
+ ) -> None :
72
+ """Uses Pathways Persistence API to serialize a jax array."""
73
+ f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
74
+ futures_results = list (map (f , locations , names , values ))
75
+ for future_result in futures_results :
76
+ future_result .result ()
77
+
66
78
async def serialize (
67
79
self ,
68
80
values : Sequence [jax .Array ],
@@ -76,8 +88,12 @@ async def serialize(
76
88
raise ValueError ("Casting during save not supported for Pathways." )
77
89
78
90
locations , names = extract_parent_dir_and_name (infos )
79
- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
80
- return list (map (f , locations , names , values ))
91
+ return [
92
+ future .CommitFutureAwaitingContractedSignals (
93
+ self ._background_serialize (values , locations , names ),
94
+ name = "cloud_pathways_array_handler" ,
95
+ )
96
+ ]
81
97
82
98
async def deserialize (
83
99
self ,
0 commit comments