Skip to content

Commit d0bf637

Browse files
committed
reorganized configure_optical_flows to not initialize for null flow weight
1 parent cdda5c0 commit d0bf637

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

src/pytti/LossAug/LossOrchestratorClass.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,31 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
125125
def configure_optical_flows(img, params, loss_augs):
126126
logger.debug(params.device)
127127
_device = params.device
128+
129+
# this shouldn't be in this function based on the name.
130+
# other loss augs
131+
if params.smoothing_weight != 0:
132+
loss_augs.append(
133+
TVLoss(weight=params.smoothing_weight)
134+
) # , device=params.device))
135+
128136
optical_flows = []
129137
if params.animation_mode == "Video Source":
130138
if params.flow_stabilization_weight == "":
131139
params.flow_stabilization_weight = "0"
132140
# TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?
133-
134-
for i in range(params.flow_long_term_samples + 1):
135-
optical_flow = OpticalFlowLoss(
136-
comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE)
137-
weight=params.flow_stabilization_weight,
138-
name=f"optical flow stabilization (frame {-2**i}) (direct)",
139-
image_shape=img.image_shape,
140-
device=_device,
141-
) # , device=device)
142-
optical_flow.set_enabled(False)
143-
optical_flows.append(optical_flow)
141+
if params.flow_stabilization_weight != "0":
142+
# TO DO: if weight is parameterized, need to do a parameteric evaluation here.
143+
for i in range(params.flow_long_term_samples + 1):
144+
optical_flow = OpticalFlowLoss(
145+
comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE)
146+
weight=params.flow_stabilization_weight,
147+
name=f"optical flow stabilization (frame {-2**i}) (direct)",
148+
image_shape=img.image_shape,
149+
device=_device,
150+
) # , device=device)
151+
optical_flow.set_enabled(False)
152+
optical_flows.append(optical_flow)
144153

145154
elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [
146155
"0",
@@ -158,13 +167,6 @@ def configure_optical_flows(img, params, loss_augs):
158167

159168
loss_augs.extend(optical_flows)
160169

161-
# this shouldn't be in this function based on the name.
162-
# other loss augs
163-
if params.smoothing_weight != 0:
164-
loss_augs.append(
165-
TVLoss(weight=params.smoothing_weight)
166-
) # , device=params.device))
167-
168170
return img, loss_augs, optical_flows
169171

170172

0 commit comments

Comments
 (0)