Skip to content

Commit ff5a7de

Browse files
committed
feat(mean): add downsampled loss computation to improve efficiency
1 parent 6df0bef commit ff5a7de

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

deepem/loss/mean.py

+10
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
delta_d: float = 1.5,
7979
recompute_ext: bool = False,
8080
mask_background: bool = True,
81+
loss_scale_factor: tuple[float, float, float] | None = None,
8182
**kwargs,
8283
):
8384
super().__init__()
@@ -88,6 +89,7 @@ def __init__(
8889
self.delta_d = delta_d # Distance (inter-cluster push force) hinge
8990
self.recompute_ext = recompute_ext
9091
self.mask_background = mask_background
92+
self.loss_scale_factor = loss_scale_factor
9193

9294
def forward(
9395
self,
@@ -104,6 +106,14 @@ def forward(
104106
"""
105107
device = embd.device
106108

109+
# Downsample if enabled
110+
if self.loss_scale_factor is not None:
111+
embd = F.interpolate(embd, scale_factor=self.loss_scale_factor, mode='trilinear', align_corners=False)
112+
trgt = F.interpolate(trgt, scale_factor=self.loss_scale_factor, mode='nearest')
113+
mask = F.interpolate(mask, scale_factor=self.loss_scale_factor, mode='nearest')
114+
if splt is not None:
115+
splt = F.interpolate(splt, scale_factor=self.loss_scale_factor, mode='nearest')
116+
107117
groups = None
108118
if self.recompute_ext:
109119
assert splt is not None

deepem/train/option.py

+2
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def initialize(self):
8383
self.parser.add_argument('--delta_d', type=float, default=1.5)
8484
self.parser.add_argument('--recompute_ext', action='store_true')
8585
self.parser.add_argument('--no_mask_background', action='store_true')
86+
self.parser.add_argument('--loss_scale_factor', type=vec3f, default=None)
8687

8788
# Optimizer
8889
self.parser.add_argument('--optim', default='Adam')
@@ -227,6 +228,7 @@ def parse(self):
227228
opt.metric_params['delta_d'] = opt.delta_d
228229
opt.metric_params['recompute_ext'] = opt.recompute_ext
229230
opt.metric_params['mask_background'] = not opt.no_mask_background
231+
opt.metric_params['loss_scale_factor'] = opt.loss_scale_factor
230232

231233
# Optimizer
232234
if opt.optim == 'Adam':

0 commit comments

Comments
 (0)