@@ -132,6 +132,7 @@ def load_sample(
132
132
zettaset_lookup : dict [str , str ] | None = None ,
133
133
zettaset_resolution : tuple [int , int , int ] | None = None ,
134
134
requires_binarize : list [str ] = [],
135
+ zettaset_share_mask : str | None = None ,
135
136
** kwargs
136
137
) -> dict [str , np .ndarray ]:
137
138
"""Load image and labels from a Sample."""
@@ -164,6 +165,16 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
164
165
# Assumes that zettaset's annotation names follow DeepEM's convention.
165
166
zettaset_lookup = zettaset_lookup or {x : x for x in sample .annotation_names }
166
167
168
+ # Shared mask
169
+ shared_mask = None
170
+ if (not no_mask ) and zettaset_share_mask :
171
+ key = zettaset_share_mask
172
+ mask_key = f"{ zettaset_share_mask } _mask"
173
+ if key not in sample .masks :
174
+ raise KeyError (f"Mask '{ mask_key } ' not found." )
175
+ mask_vol = sample .read_mask (key )[key ]
176
+ shared_mask = convert_array (mask_vol ).astype ("uint8" )
177
+
167
178
# Process annotations
168
179
for name , key in zettaset_lookup .items ():
169
180
@@ -178,7 +189,9 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
178
189
179
190
# Mask
180
191
mask_key = f"{ name } _mask"
181
- if (not no_mask ) and (key in sample .masks ):
192
+ if shared_mask is not None :
193
+ dset [mask_key ] = shared_mask
194
+ elif (not no_mask ) and (key in sample .masks ):
182
195
mask_vol = sample .read_mask (key )[key ]
183
196
dset [mask_key ] = convert_array (mask_vol ).astype ("uint8" )
184
197
else :
0 commit comments