@@ -125,22 +125,31 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
125
125
def configure_optical_flows (img , params , loss_augs ):
126
126
logger .debug (params .device )
127
127
_device = params .device
128
+
129
+ # this shouldn't be in this function based on the name.
130
+ # other loss augs
131
+ if params .smoothing_weight != 0 :
132
+ loss_augs .append (
133
+ TVLoss (weight = params .smoothing_weight )
134
+ ) # , device=params.device))
135
+
128
136
optical_flows = []
129
137
if params .animation_mode == "Video Source" :
130
138
if params .flow_stabilization_weight == "" :
131
139
params .flow_stabilization_weight = "0"
132
140
# TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?
133
-
134
- for i in range (params .flow_long_term_samples + 1 ):
135
- optical_flow = OpticalFlowLoss (
136
- comp = torch .zeros (1 , 1 , 1 , 1 , device = _device ), # ,device=DEVICE)
137
- weight = params .flow_stabilization_weight ,
138
- name = f"optical flow stabilization (frame { - 2 ** i } ) (direct)" ,
139
- image_shape = img .image_shape ,
140
- device = _device ,
141
- ) # , device=device)
142
- optical_flow .set_enabled (False )
143
- optical_flows .append (optical_flow )
141
+ if params .flow_stabilization_weight != "0" :
142
+ # TO DO: if weight is parameterized, need to do a parameteric evaluation here.
143
+ for i in range (params .flow_long_term_samples + 1 ):
144
+ optical_flow = OpticalFlowLoss (
145
+ comp = torch .zeros (1 , 1 , 1 , 1 , device = _device ), # ,device=DEVICE)
146
+ weight = params .flow_stabilization_weight ,
147
+ name = f"optical flow stabilization (frame { - 2 ** i } ) (direct)" ,
148
+ image_shape = img .image_shape ,
149
+ device = _device ,
150
+ ) # , device=device)
151
+ optical_flow .set_enabled (False )
152
+ optical_flows .append (optical_flow )
144
153
145
154
elif params .animation_mode == "3D" and params .flow_stabilization_weight not in [
146
155
"0" ,
@@ -158,13 +167,6 @@ def configure_optical_flows(img, params, loss_augs):
158
167
159
168
loss_augs .extend (optical_flows )
160
169
161
- # this shouldn't be in this function based on the name.
162
- # other loss augs
163
- if params .smoothing_weight != 0 :
164
- loss_augs .append (
165
- TVLoss (weight = params .smoothing_weight )
166
- ) # , device=params.device))
167
-
168
170
return img , loss_augs , optical_flows
169
171
170
172
0 commit comments