Skip to content

Commit 4b38c2a

Browse files
committed
feat(Zettaset): enable default volume for user-specified annotations
1 parent 4cac781 commit 4b38c2a

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

deepem/data/dataset/multi_zettaset.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def load_sample(
130130
padding: tuple[int, int, int] = (0, 0, 0),
131131
no_mask: bool = False,
132132
zettaset_lookup: dict[str, str] | None = None,
133+
zettaset_default: dict[str, str] = {},
133134
zettaset_resolution: tuple[int, int, int] | None = None,
134135
requires_binarize: list[str] = [],
135136
**kwargs
@@ -168,8 +169,17 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
168169
for name, key in zettaset_lookup.items():
169170

170171
# Annotation
171-
vol = sample.read(key)[key]
172-
dset[name] = convert_array(vol)
172+
if key in sample.annotation_names:
173+
vol = sample.read(key)[key]
174+
dset[name] = convert_array(vol)
175+
else:
176+
if key in zettaset_default:
177+
shape = tuple(map(int, bbox.size()))
178+
dtype = zettaset_default[key]
179+
vol = np.zeros(shape, dtype=dtype)
180+
dset[name] = np.transpose(vol, (2, 1, 0))
181+
else:
182+
raise KeyError(f"Annotation '{key}' not found.")
173183
anno_log = f"\t{name}: {dset[name].shape}"
174184

175185
# Binarize

deepem/train/option.py

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def initialize(self):
2828
self.parser.add_argument('--zettaset_path', type=str, default=[], nargs='+')
2929
self.parser.add_argument('--zettaset_specs', type=json.loads, default={})
3030
self.parser.add_argument('--zettaset_lookup', type=json.loads, default=None)
31+
self.parser.add_argument('--zettaset_default', type=json.loads, default={})
3132
self.parser.add_argument('--zettaset_padding', type=vec3, default=(0, 0, 0))
3233
self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={})
3334
self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None)
@@ -314,6 +315,7 @@ def parse(self):
314315
class_keys=class_keys,
315316
glia_mask=opt.glia_mask,
316317
zettaset_lookup=opt.zettaset_lookup,
318+
zettaset_default=opt.zettaset_default,
317319
zettaset_padding=opt.zettaset_padding,
318320
zettaset_padding_spec=opt.zettaset_padding_spec,
319321
zettaset_resolution=opt.zettaset_resolution,

0 commit comments

Comments
 (0)