Skip to content

Commit 439f7b3

Browse files
committed
feat: support semantic segmentation
1 parent 69775b4 commit 439f7b3

File tree

3 files changed

+105
-7
lines changed

3 files changed

+105
-7
lines changed

deepem/data/dataset/multi_zettaset.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def load_sample(
140140
zettaset_lookup: dict[str, str] | None = None,
141141
requires_binarize: list[str] = [],
142142
zettaset_share_mask: str | None = None,
143+
semantic_mapping: dict[str, int] = {},
143144
**kwargs
144145
) -> dict[str, np.ndarray]:
145146
"""Load image and labels from a Sample."""
@@ -190,8 +191,10 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
190191
dset[name] = convert_array(vol)
191192
anno_log = f"\t{name}: {dset[name].shape}"
192193

193-
# Binarize
194-
if name in requires_binarize:
194+
# Semantic mapping or binarize
195+
if name in semantic_mapping:
196+
dset[name] = (dset[name] == semantic_mapping[name]).astype("uint8")
197+
elif name in requires_binarize:
195198
dset[name] = (dset[name] > 0).astype("uint8")
196199

197200
# Mask

deepem/test/option.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
from collections import OrderedDict
23
import json
34
import math
45
import os
@@ -63,11 +64,20 @@ def initialize(self):
6364
self.parser.add_argument('--mye', action='store_true')
6465
self.parser.add_argument('--mye_thresh', type=float, default=0.5)
6566
self.parser.add_argument('--blv', action='store_true')
66-
self.parser.add_argument('--blv_num_channels', type=int, default=2)
67-
self.parser.add_argument('--glia', action='store_true')
67+
self.parser.add_argument('--blv_num_channels', type=int, default=1)
68+
self.parser.add_argument('--glia', action='store_true')
6869
self.parser.add_argument('--sem', action='store_true')
6970
self.parser.add_argument('--img', action='store_true')
7071

72+
# Semantic segmentation
73+
self.parser.add_argument('--semantic', action='store_true')
74+
self.parser.add_argument('--dend', action='store_true') # Dendrite
75+
self.parser.add_argument('--axon', action='store_true') # Axon
76+
self.parser.add_argument('--soma', action='store_true') # Soma
77+
self.parser.add_argument('--nucl', action='store_true') # Nucleus
78+
self.parser.add_argument('--ecs', action='store_true') # Extracellular space
79+
self.parser.add_argument('--other', action='store_true') # Other class
80+
7181
# Test-time augmentation
7282
self.parser.add_argument('--test_aug', type=int, default=None, nargs='+')
7383
self.parser.add_argument('--test_aug16', action='store_true')
@@ -203,6 +213,35 @@ def parse(self):
203213
opt.out_spec['bvessel'] = (1,) + opt.outputsz
204214
if opt.img:
205215
opt.out_spec['image'] = (1,) + opt.outputsz
216+
if opt.dend:
217+
opt.out_spec['dendrite'] = (1,) + opt.outputsz
218+
if opt.axon:
219+
opt.out_spec['axon'] = (1,) + opt.outputsz
220+
if opt.soma:
221+
opt.out_spec['soma'] = (1,) + opt.outputsz
222+
if opt.nucl:
223+
opt.out_spec['nucleus'] = (1,) + opt.outputsz
224+
if opt.ecs:
225+
opt.out_spec['extracellular_space'] = (1,) + opt.outputsz
226+
if opt.other:
227+
opt.out_spec['other_class'] = (1,) + opt.outputsz
228+
229+
# Semantic segmentation
230+
if opt.semantic:
231+
required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel']
232+
233+
# Ensure all required keys are present in the opt.out_spec
234+
assert all(key in opt.out_spec for key in required_keys)
235+
236+
# Use OrderedDict to maintain order of required keys followed by other keys
237+
out_spec_new = OrderedDict((key, opt.out_spec[key]) for key in required_keys)
238+
239+
# Add remaining keys to out_spec_new
240+
out_spec_new.update((key, opt.out_spec[key]) for key in opt.out_spec if key not in required_keys)
241+
242+
# Convert back to standard dict if necessary
243+
opt.out_spec = dict(out_spec_new)
244+
206245
assert(len(opt.out_spec) > 0)
207246

208247
# Scan spec
@@ -240,6 +279,34 @@ def parse(self):
240279
opt.scan_spec['bvessel'] = (1,) + opt.outputsz
241280
if opt.img:
242281
opt.scan_spec['image'] = (1,) + opt.outputsz
282+
if opt.dend:
283+
opt.scan_spec['dendrite'] = (1,) + opt.outputsz
284+
if opt.axon:
285+
opt.scan_spec['axon'] = (1,) + opt.outputsz
286+
if opt.soma:
287+
opt.scan_spec['soma'] = (1,) + opt.outputsz
288+
if opt.nucl:
289+
opt.scan_spec['nucleus'] = (1,) + opt.outputsz
290+
if opt.ecs:
291+
opt.scan_spec['extracellular_space'] = (1,) + opt.outputsz
292+
if opt.other:
293+
opt.scan_spec['other_class'] = (1,) + opt.outputsz
294+
295+
# Semantic segmentation
296+
if opt.semantic:
297+
required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel']
298+
299+
# Ensure all required keys are present in the opt.scan_spec
300+
assert all(key in opt.scan_spec for key in required_keys)
301+
302+
# Use OrderedDict to maintain order of required keys followed by other keys
303+
scan_spec_new = OrderedDict((key, opt.scan_spec[key]) for key in required_keys)
304+
305+
# Add remaining keys to scan_spec_new
306+
scan_spec_new.update((key, opt.scan_spec[key]) for key in opt.scan_spec if key not in required_keys)
307+
308+
# Convert back to standard dict if necessary
309+
opt.scan_spec = dict(scan_spec_new)
243310

244311
# Test-time augmentation
245312
if opt.test_aug16:

deepem/train/option.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,20 @@ def initialize(self):
144144
self.parser.add_argument('--mye', type=float, default=0) # Myelin
145145
self.parser.add_argument('--fld', type=float, default=0) # Fold
146146
self.parser.add_argument('--blv', type=float, default=0) # Blood vessel
147-
self.parser.add_argument('--blv_num_channels', type=int, default=2)
148-
self.parser.add_argument('--glia', type=float, default=0) # Glia
147+
self.parser.add_argument('--blv_num_channels', type=int, default=1)
148+
self.parser.add_argument('--glia', type=float, default=0) # Glia
149149
self.parser.add_argument('--glia_mask', action='store_true')
150-
self.parser.add_argument('--soma', type=float, default=0) # Soma
151150
self.parser.add_argument('--img', type=float, default=0) # Image
152151

152+
# Semantic segmentation
153+
self.parser.add_argument('--sem', action='store_true')
154+
self.parser.add_argument('--dend', type=float, default=0) # Dendrite
155+
self.parser.add_argument('--axon', type=float, default=0) # Axon
156+
self.parser.add_argument('--soma', type=float, default=0) # Soma
157+
self.parser.add_argument('--nucl', type=float, default=0) # Nucleus
158+
self.parser.add_argument('--ecs', type=float, default=0) # Extracellular space
159+
self.parser.add_argument('--other', type=float, default=0) # Other class
160+
153161
# Metric learning
154162
self.parser.add_argument('--vec', type=float, default=0)
155163
self.parser.add_argument('--embed_dim', type=int, default=12)
@@ -287,6 +295,22 @@ def parse(self):
287295
'soma': ('soma', 1),
288296
'img': ('image', 1),
289297
'vec': ('embedding', opt.embed_dim),
298+
'dend': ('dendrite', 1),
299+
'axon': ('axon', 1),
300+
'nucl': ('nucleus', 1),
301+
'ecs': ('extracellular_space', 1),
302+
'other': ('other_class', 1),
303+
}
304+
305+
semantic_mapping = {
306+
'dendrite': 1,
307+
'axon': 2,
308+
'soma': 3,
309+
'nucleus': 4,
310+
'glia': 5,
311+
'extracellular_space': 6,
312+
'blood_vessel': 7,
313+
'other_class': 10,
290314
}
291315

292316
requires_binarize = [
@@ -301,6 +325,9 @@ def parse(self):
301325
if opt.blv_num_channels == 1:
302326
requires_binarize.append("blood_vessel")
303327

328+
if opt.sem:
329+
requires_binarize = [x for x in requires_binarize if x not in semantic_mapping]
330+
304331
# Test training
305332
if opt.test:
306333
opt.eval_intv = 100
@@ -329,6 +356,7 @@ def parse(self):
329356
zettaset_mask=not opt.zettaset_no_mask,
330357
requires_binarize=requires_binarize,
331358
zettaset_share_mask=opt.zettaset_share_mask,
359+
semantic_mapping=semantic_mapping if opt.sem else {},
332360
)
333361

334362
# ONNX

0 commit comments

Comments
 (0)