@@ -227,13 +227,17 @@ def forward(
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
 
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
 
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ def forward(
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
 
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation