@@ -125,22 +125,31 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
125125def configure_optical_flows (img , params , loss_augs ):
126126 logger .debug (params .device )
127127 _device = params .device
128+
129+ # this shouldn't be in this function based on the name.
130+ # other loss augs
131+ if params .smoothing_weight != 0 :
132+ loss_augs .append (
133+ TVLoss (weight = params .smoothing_weight )
134+ ) # , device=params.device))
135+
128136 optical_flows = []
129137 if params .animation_mode == "Video Source" :
130138 if params .flow_stabilization_weight == "" :
131139 params .flow_stabilization_weight = "0"
132140 # TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?
133-
134- for i in range (params .flow_long_term_samples + 1 ):
135- optical_flow = OpticalFlowLoss (
136- comp = torch .zeros (1 , 1 , 1 , 1 , device = _device ), # ,device=DEVICE)
137- weight = params .flow_stabilization_weight ,
138- name = f"optical flow stabilization (frame { - 2 ** i } ) (direct)" ,
139- image_shape = img .image_shape ,
140- device = _device ,
141- ) # , device=device)
142- optical_flow .set_enabled (False )
143- optical_flows .append (optical_flow )
141+ if params .flow_stabilization_weight != "0" :
142+ # TO DO: if weight is parameterized, need to do a parameteric evaluation here.
143+ for i in range (params .flow_long_term_samples + 1 ):
144+ optical_flow = OpticalFlowLoss (
145+ comp = torch .zeros (1 , 1 , 1 , 1 , device = _device ), # ,device=DEVICE)
146+ weight = params .flow_stabilization_weight ,
147+ name = f"optical flow stabilization (frame { - 2 ** i } ) (direct)" ,
148+ image_shape = img .image_shape ,
149+ device = _device ,
150+ ) # , device=device)
151+ optical_flow .set_enabled (False )
152+ optical_flows .append (optical_flow )
144153
145154 elif params .animation_mode == "3D" and params .flow_stabilization_weight not in [
146155 "0" ,
@@ -158,13 +167,6 @@ def configure_optical_flows(img, params, loss_augs):
158167
159168 loss_augs .extend (optical_flows )
160169
161- # this shouldn't be in this function based on the name.
162- # other loss augs
163- if params .smoothing_weight != 0 :
164- loss_augs .append (
165- TVLoss (weight = params .smoothing_weight )
166- ) # , device=params.device))
167-
168170 return img , loss_augs , optical_flows
169171
170172
0 commit comments