Skip to content

Commit 4c3ae03

Browse files
committed
feat(Zettaset): allow sharing of a user-specified mask volume across different annotations
1 parent 4cac781 commit 4c3ae03

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

deepem/data/dataset/multi_zettaset.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def load_sample(
132132
zettaset_lookup: dict[str, str] | None = None,
133133
zettaset_resolution: tuple[int, int, int] | None = None,
134134
requires_binarize: list[str] = [],
135+
zettaset_share_mask: str | None = None,
135136
**kwargs
136137
) -> dict[str, np.ndarray]:
137138
"""Load image and labels from a Sample."""
@@ -164,6 +165,16 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
164165
# Assumes that zettaset's annotation names follow DeepEM's convention.
165166
zettaset_lookup = zettaset_lookup or {x: x for x in sample.annotation_names}
166167

168+
# Shared mask
169+
shared_mask = None
170+
if (not no_mask) and zettaset_share_mask:
171+
key = zettaset_share_mask
172+
mask_key = f"{zettaset_share_mask}_mask"
173+
if key not in sample.masks:
174+
raise KeyError(f"Mask '{mask_key}' not found.")
175+
mask_vol = sample.read_mask(key)[key]
176+
shared_mask = convert_array(mask_vol).astype("uint8")
177+
167178
# Process annotations
168179
for name, key in zettaset_lookup.items():
169180

@@ -178,7 +189,9 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
178189

179190
# Mask
180191
mask_key = f"{name}_mask"
181-
if (not no_mask) and (key in sample.masks):
192+
if shared_mask is not None:
193+
dset[mask_key] = shared_mask
194+
elif (not no_mask) and (key in sample.masks):
182195
mask_vol = sample.read_mask(key)[key]
183196
dset[mask_key] = convert_array(mask_vol).astype("uint8")
184197
else:

deepem/train/option.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def initialize(self):
3232
self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={})
3333
self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None)
3434
self.parser.add_argument('--zettaset_no_mask', action='store_true')
35+
self.parser.add_argument('--zettaset_share_mask', type=str, default=None)
3536

3637
# file synchronization for spot/preemptible training
3738
self.parser.add_argument('--samwise_map', nargs='*', default=None)
@@ -319,6 +320,7 @@ def parse(self):
319320
zettaset_resolution=opt.zettaset_resolution,
320321
zettaset_mask=not opt.zettaset_no_mask,
321322
requires_binarize=requires_binarize,
323+
zettaset_share_mask=opt.zettaset_share_mask,
322324
)
323325

324326
# ONNX

0 commit comments

Comments
 (0)