Skip to content

Commit 0be5caa

Browse files
committed
fix: apply correct weights for static class balancing
1 parent 439f7b3 commit 0be5caa

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

deepem/loss/affinity.py

-13
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,6 @@ def forward(self, preds, targets, masks):
5656

5757
return loss, nmsk
5858

59-
# def class_balancing(self, target, mask):
60-
# if not self.balancing:
61-
# return mask
62-
# dtype = mask.type()
63-
# m_int = mask * torch.eq(target, 1).type(dtype)
64-
# m_ext = mask * torch.eq(target, 0).type(dtype)
65-
# n_int = m_int.sum().item()
66-
# n_ext = m_ext.sum().item()
67-
# if n_int > 0 and n_ext > 0:
68-
# m_int *= n_ext/(n_int + n_ext)
69-
# m_ext *= n_int/(n_int + n_ext)
70-
# return (m_int + m_ext).type(dtype)
71-
7259

7360
class AffinityLoss(nn.Module):
7461
def __init__(self, edges, criterion, size_average=False,

deepem/train/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def get_criteria(opt):
2020
weight1=opt.class_weight1,
2121
) if opt.class_balancing else None
2222

23+
is_dynamic = (opt.class_weight0 is None) and (opt.class_weight1 is None)
24+
2325
for k in opt.out_spec:
2426
if k == 'affinity' or k == 'long_range':
2527
if k == 'affinity':
@@ -42,7 +44,14 @@ def get_criteria(opt):
4244
params['margin0'] = 0
4345
params['margin1'] = 0
4446
params['inverse'] = False
45-
params['class_balancer'] = balancer
47+
params['class_balancer'] = (
48+
balancer
49+
if is_dynamic
50+
else BinaryWeightBalancer(
51+
weight0=opt.class_weight1,
52+
weight1=opt.class_weight0,
53+
)
54+
)
4655
criteria[k] = getattr(loss, 'BCELoss')(**params)
4756
return criteria
4857

0 commit comments

Comments
 (0)