Spaces:
Running
on
L4
Running
on
L4
Migrated from GitHub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- LICENSE +201 -0
- ORIGINAL_README.md +158 -0
- assets/demo1_audio.wav +0 -0
- assets/demo1_video.mp4 +3 -0
- assets/demo2_audio.wav +0 -0
- assets/demo2_video.mp4 +3 -0
- assets/demo3_audio.wav +0 -0
- assets/demo3_video.mp4 +3 -0
- assets/framework.png +0 -0
- configs/audio.yaml +23 -0
- configs/scheduler_config.json +13 -0
- configs/syncnet/syncnet_16_latent.yaml +46 -0
- configs/syncnet/syncnet_16_pixel.yaml +45 -0
- configs/syncnet/syncnet_25_pixel.yaml +45 -0
- configs/unet/first_stage.yaml +103 -0
- configs/unet/second_stage.yaml +103 -0
- data_processing_pipeline.sh +9 -0
- eval/detectors/README.md +3 -0
- eval/detectors/__init__.py +1 -0
- eval/detectors/s3fd/__init__.py +61 -0
- eval/detectors/s3fd/box_utils.py +221 -0
- eval/detectors/s3fd/nets.py +174 -0
- eval/draw_syncnet_lines.py +70 -0
- eval/eval_fvd.py +96 -0
- eval/eval_sync_conf.py +77 -0
- eval/eval_sync_conf.sh +2 -0
- eval/eval_syncnet_acc.py +118 -0
- eval/eval_syncnet_acc.sh +3 -0
- eval/fvd.py +56 -0
- eval/hyper_iqa.py +343 -0
- eval/inference_videos.py +37 -0
- eval/syncnet/__init__.py +1 -0
- eval/syncnet/syncnet.py +113 -0
- eval/syncnet/syncnet_eval.py +220 -0
- eval/syncnet_detect.py +251 -0
- inference.sh +9 -0
- latentsync/data/syncnet_dataset.py +153 -0
- latentsync/data/unet_dataset.py +164 -0
- latentsync/models/attention.py +492 -0
- latentsync/models/motion_module.py +332 -0
- latentsync/models/resnet.py +234 -0
- latentsync/models/syncnet.py +233 -0
- latentsync/models/syncnet_wav2lip.py +90 -0
- latentsync/models/unet.py +528 -0
- latentsync/models/unet_blocks.py +903 -0
- latentsync/models/utils.py +19 -0
- latentsync/pipelines/lipsync_pipeline.py +470 -0
- latentsync/trepa/__init__.py +64 -0
- latentsync/trepa/third_party/VideoMAEv2/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demo1_video.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/demo2_video.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/demo3_video.mp4 filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
ORIGINAL_README.md
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LatentSync: Audio Conditioned Latent Diffusion Models for Lip Sync
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
|
5 |
+
[![arXiv](https://img.shields.io/badge/arXiv_paper-2412.09262-b31b1b)](https://arxiv.org/abs/2412.09262)
|
6 |
+
|
7 |
+
</div>
|
8 |
+
|
9 |
+
## 📖 Abstract
|
10 |
+
|
11 |
+
We present *LatentSync*, an end-to-end lip sync framework based on audio conditioned latent diffusion models without any intermediate motion representation, diverging from previous diffusion-based lip sync methods based on pixel space diffusion or two-stage generation. Our framework can leverage the powerful capabilities of Stable Diffusion to directly model complex audio-visual correlations. Additionally, we found that the diffusion-based lip sync methods exhibit inferior temporal consistency due to the inconsistency in the diffusion process across different frames. We propose *Temporal REPresentation Alignment (TREPA)* to enhance temporal consistency while preserving lip-sync accuracy. TREPA uses temporal representations extracted by large-scale self-supervised video models to align the generated frames with the ground truth frames.
|
12 |
+
|
13 |
+
## 🏗️ Framework
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
<img src="assets/framework.png" width=100%>
|
17 |
+
<p>
|
18 |
+
|
19 |
+
LatentSync uses the Whisper to convert melspectrogram into audio embeddings, which are then integrated into the U-Net via cross-attention layers. The reference and masked frames are channel-wise concatenated with noised latents as the input of U-Net. In the training process, we use one-step method to get estimated clean latents from predicted noises, which are then decoded to obtain the estimated clean frames. The TREPA, LPIPS and SyncNet loss are added in the pixel space.
|
20 |
+
|
21 |
+
## 🎬 Demo
|
22 |
+
|
23 |
+
<table class="center">
|
24 |
+
<tr style="font-weight: bolder;text-align:center;">
|
25 |
+
<td width="50%"><b>Original video</b></td>
|
26 |
+
<td width="50%"><b>Lip-synced video</b></td>
|
27 |
+
</tr>
|
28 |
+
<tr>
|
29 |
+
<td>
|
30 |
+
<video src=https://github.com/user-attachments/assets/ff3a84da-dc9b-498a-950f-5c54f58dd5c5 controls preload></video>
|
31 |
+
</td>
|
32 |
+
<td>
|
33 |
+
<video src=https://github.com/user-attachments/assets/150e00fd-381e-4421-a478-a9ea3d1212a8 controls preload></video>
|
34 |
+
</td>
|
35 |
+
</tr>
|
36 |
+
<tr>
|
37 |
+
<td>
|
38 |
+
<video src=https://github.com/user-attachments/assets/32c830a9-4d7d-4044-9b33-b184d8e11010 controls preload></video>
|
39 |
+
</td>
|
40 |
+
<td>
|
41 |
+
<video src=https://github.com/user-attachments/assets/84e4fe9d-b108-44a4-8712-13a012348145 controls preload></video>
|
42 |
+
</td>
|
43 |
+
</tr>
|
44 |
+
<tr>
|
45 |
+
<td>
|
46 |
+
<video src=https://github.com/user-attachments/assets/7510a448-255a-44ee-b093-a1b98bd3961d controls preload></video>
|
47 |
+
</td>
|
48 |
+
<td>
|
49 |
+
<video src=https://github.com/user-attachments/assets/6150c453-c559-4ae0-bb00-c565f135ff41 controls preload></video>
|
50 |
+
</td>
|
51 |
+
</tr>
|
52 |
+
<tr>
|
53 |
+
<td width=300px>
|
54 |
+
<video src=https://github.com/user-attachments/assets/0f7f9845-68b2-4165-bd08-c7bbe01a0e52 controls preload></video>
|
55 |
+
</td>
|
56 |
+
<td width=300px>
|
57 |
+
<video src=https://github.com/user-attachments/assets/c34fe89d-0c09-4de3-8601-3d01229a69e3 controls preload></video>
|
58 |
+
</td>
|
59 |
+
</tr>
|
60 |
+
<tr>
|
61 |
+
<td>
|
62 |
+
<video src=https://github.com/user-attachments/assets/7ce04d50-d39f-4154-932a-ec3a590a8f64 controls preload></video>
|
63 |
+
</td>
|
64 |
+
<td>
|
65 |
+
<video src=https://github.com/user-attachments/assets/70bde520-42fa-4a0e-b66c-d3040ae5e065 controls preload></video>
|
66 |
+
</td>
|
67 |
+
</tr>
|
68 |
+
</table>
|
69 |
+
|
70 |
+
(Photorealistic videos are filmed by contracted models, and anime videos are from [VASA-1](https://www.microsoft.com/en-us/research/project/vasa-1/) and [EMO](https://humanaigc.github.io/emote-portrait-alive/))
|
71 |
+
|
72 |
+
## 📑 Open-source Plan
|
73 |
+
|
74 |
+
- [x] Inference code and checkpoints
|
75 |
+
- [x] Data processing pipeline
|
76 |
+
- [x] Training code
|
77 |
+
|
78 |
+
## 🔧 Setting up the Environment
|
79 |
+
|
80 |
+
Install the required packages and download the checkpoints via:
|
81 |
+
|
82 |
+
```bash
|
83 |
+
source setup_env.sh
|
84 |
+
```
|
85 |
+
|
86 |
+
If the download is successful, the checkpoints should appear as follows:
|
87 |
+
|
88 |
+
```
|
89 |
+
./checkpoints/
|
90 |
+
|-- latentsync_unet.pt
|
91 |
+
|-- latentsync_syncnet.pt
|
92 |
+
|-- whisper
|
93 |
+
| `-- tiny.pt
|
94 |
+
|-- auxiliary
|
95 |
+
| |-- 2DFAN4-cd938726ad.zip
|
96 |
+
| |-- i3d_torchscript.pt
|
97 |
+
| |-- koniq_pretrained.pkl
|
98 |
+
| |-- s3fd-619a316812.pth
|
99 |
+
| |-- sfd_face.pth
|
100 |
+
| |-- syncnet_v2.model
|
101 |
+
| |-- vgg16-397923af.pth
|
102 |
+
| `-- vit_g_hybrid_pt_1200e_ssv2_ft.pth
|
103 |
+
```
|
104 |
+
|
105 |
+
These already include all the checkpoints required for latentsync training and inference. If you just want to try inference, you only need to download `latentsync_unet.pt` and `tiny.pt` from our [HuggingFace repo](https://huggingface.co/chunyu-li/LatentSync)
|
106 |
+
|
107 |
+
## 🚀 Inference
|
108 |
+
|
109 |
+
Run the script for inference, which requires about 6.5 GB GPU memory.
|
110 |
+
|
111 |
+
```bash
|
112 |
+
./inference.sh
|
113 |
+
```
|
114 |
+
|
115 |
+
You can change the parameter `guidance_scale` to 1.5 to improve the lip-sync accuracy.
|
116 |
+
|
117 |
+
## 🔄 Data Processing Pipeline
|
118 |
+
|
119 |
+
The complete data processing pipeline includes the following steps:
|
120 |
+
|
121 |
+
1. Remove the broken video files.
|
122 |
+
2. Resample the video FPS to 25, and resample the audio to 16000 Hz.
|
123 |
+
3. Scene detect via [PySceneDetect](https://github.com/Breakthrough/PySceneDetect).
|
124 |
+
4. Split each video into 5-10 second segments.
|
125 |
+
5. Remove videos where the face is smaller than 256 $\times$ 256, as well as videos with more than one face.
|
126 |
+
6. Affine transform the faces according to the landmarks detected by [face-alignment](https://github.com/1adrianb/face-alignment), then resize to 256 $\times$ 256.
|
127 |
+
7. Remove videos with [sync confidence score](https://www.robots.ox.ac.uk/~vgg/publications/2016/Chung16a/chung16a.pdf) lower than 3, and adjust the audio-visual offset to 0.
|
128 |
+
8. Calculate [hyperIQA](https://openaccess.thecvf.com/content_CVPR_2020/papers/Su_Blindly_Assess_Image_Quality_in_the_Wild_Guided_by_a_CVPR_2020_paper.pdf) score, and remove videos with scores lower than 40.
|
129 |
+
|
130 |
+
Run the script to execute the data processing pipeline:
|
131 |
+
|
132 |
+
```bash
|
133 |
+
./data_processing_pipeline.sh
|
134 |
+
```
|
135 |
+
|
136 |
+
You can change the parameter `input_dir` in the script to specify the data directory to be processed. The processed data will be saved in the same directory. Each step will generate a new directory to prevent the need to redo the entire pipeline in case the process is interrupted by an unexpected error.
|
137 |
+
|
138 |
+
## 🏋️♂️ Training U-Net
|
139 |
+
|
140 |
+
Before training, you must process the data as described above and download all the checkpoints. We released a pretrained SyncNet with 94% accuracy on the VoxCeleb2 dataset for the supervision of U-Net training. Note that this SyncNet is trained on affine transformed videos, so when using or evaluating this SyncNet, you need to perform affine transformation on the video first (the code of affine transformation is included in the data processing pipeline).
|
141 |
+
|
142 |
+
If all the preparations are complete, you can train the U-Net with the following script:
|
143 |
+
|
144 |
+
```bash
|
145 |
+
./train_unet.sh
|
146 |
+
```
|
147 |
+
|
148 |
+
You should change the parameters in U-Net config file to specify the data directory, checkpoint save path, and other training hyperparameters.
|
149 |
+
|
150 |
+
## 🏋️♂️ Training SyncNet
|
151 |
+
|
152 |
+
In case you want to train SyncNet on your own datasets, you can run the following script. The data processing pipeline for SyncNet is the same as U-Net.
|
153 |
+
|
154 |
+
```bash
|
155 |
+
./train_syncnet.sh
|
156 |
+
```
|
157 |
+
|
158 |
+
After `validations_steps` training, the loss charts will be saved in `train_output_dir`. They contain both the training and validation loss.
|
assets/demo1_audio.wav
ADDED
Binary file (307 kB). View file
|
|
assets/demo1_video.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed2dd1e2001aa605c3f2d77672a8af4ed55e427a85c55d408adfc3d5076bc872
|
3 |
+
size 1240008
|
assets/demo2_audio.wav
ADDED
Binary file (635 kB). View file
|
|
assets/demo2_video.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c3f10288e0642e587a95c0040e6966f8f6b7e003c3a17b572f72472b896d8ff
|
3 |
+
size 1772492
|
assets/demo3_audio.wav
ADDED
Binary file (594 kB). View file
|
|
assets/demo3_video.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfa177b2a44f7809f606285c120e270d526caa50d708ec95e0f614d220970e0f
|
3 |
+
size 2112370
|
assets/framework.png
ADDED
configs/audio.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
audio:
|
2 |
+
num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality
|
3 |
+
rescale: true # Whether to rescale audio prior to preprocessing
|
4 |
+
rescaling_max: 0.9 # Rescaling value
|
5 |
+
use_lws:
|
6 |
+
false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
7 |
+
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
8 |
+
# Does not work if n_ffit is not multiple of hop_size!!
|
9 |
+
n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
|
10 |
+
hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
11 |
+
win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
12 |
+
sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
13 |
+
frame_shift_ms: null
|
14 |
+
signal_normalization: true
|
15 |
+
allow_clipping_in_normalization: true
|
16 |
+
symmetric_mels: true
|
17 |
+
max_abs_value: 4.0
|
18 |
+
preemphasize: true # whether to apply filter
|
19 |
+
preemphasis: 0.97 # filter coefficient.
|
20 |
+
min_level_db: -100
|
21 |
+
ref_level_db: 20
|
22 |
+
fmin: 55
|
23 |
+
fmax: 7600
|
configs/scheduler_config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "DDIMScheduler",
|
3 |
+
"_diffusers_version": "0.6.0.dev0",
|
4 |
+
"beta_end": 0.012,
|
5 |
+
"beta_schedule": "scaled_linear",
|
6 |
+
"beta_start": 0.00085,
|
7 |
+
"clip_sample": false,
|
8 |
+
"num_train_timesteps": 1000,
|
9 |
+
"set_alpha_to_one": false,
|
10 |
+
"steps_offset": 1,
|
11 |
+
"trained_betas": null,
|
12 |
+
"skip_prk_steps": true
|
13 |
+
}
|
configs/syncnet/syncnet_16_latent.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
audio_encoder: # input (1, 80, 52)
|
3 |
+
in_channels: 1
|
4 |
+
block_out_channels: [32, 64, 128, 256, 512, 1024]
|
5 |
+
downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]]
|
6 |
+
attn_blocks: [0, 0, 0, 0, 0, 0]
|
7 |
+
dropout: 0.0
|
8 |
+
visual_encoder: # input (64, 32, 32)
|
9 |
+
in_channels: 64
|
10 |
+
block_out_channels: [64, 128, 256, 256, 512, 1024]
|
11 |
+
downsample_factors: [2, 2, 2, 1, 2, 2]
|
12 |
+
attn_blocks: [0, 0, 0, 0, 0, 0]
|
13 |
+
dropout: 0.0
|
14 |
+
|
15 |
+
ckpt:
|
16 |
+
resume_ckpt_path: ""
|
17 |
+
inference_ckpt_path: ""
|
18 |
+
save_ckpt_steps: 2500
|
19 |
+
|
20 |
+
data:
|
21 |
+
train_output_dir: output/syncnet
|
22 |
+
num_val_samples: 1200
|
23 |
+
batch_size: 120 # 40
|
24 |
+
num_workers: 11 # 11
|
25 |
+
latent_space: true
|
26 |
+
num_frames: 16
|
27 |
+
resolution: 256
|
28 |
+
train_fileslist: ""
|
29 |
+
train_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train
|
30 |
+
val_fileslist: ""
|
31 |
+
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
32 |
+
audio_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
33 |
+
lower_half: false
|
34 |
+
pretrained_audio_model_path: facebook/wav2vec2-large-xlsr-53
|
35 |
+
audio_sample_rate: 16000
|
36 |
+
video_fps: 25
|
37 |
+
|
38 |
+
optimizer:
|
39 |
+
lr: 1e-5
|
40 |
+
max_grad_norm: 1.0
|
41 |
+
|
42 |
+
run:
|
43 |
+
max_train_steps: 10000000
|
44 |
+
validation_steps: 2500
|
45 |
+
mixed_precision_training: true
|
46 |
+
seed: 42
|
configs/syncnet/syncnet_16_pixel.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
audio_encoder: # input (1, 80, 52)
|
3 |
+
in_channels: 1
|
4 |
+
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
5 |
+
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
6 |
+
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
7 |
+
dropout: 0.0
|
8 |
+
visual_encoder: # input (48, 128, 256)
|
9 |
+
in_channels: 48
|
10 |
+
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
11 |
+
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
12 |
+
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
13 |
+
dropout: 0.0
|
14 |
+
|
15 |
+
ckpt:
|
16 |
+
resume_ckpt_path: ""
|
17 |
+
inference_ckpt_path: checkpoints/latentsync_syncnet.pt
|
18 |
+
save_ckpt_steps: 2500
|
19 |
+
|
20 |
+
data:
|
21 |
+
train_output_dir: debug/syncnet
|
22 |
+
num_val_samples: 2048
|
23 |
+
batch_size: 128 # 128
|
24 |
+
num_workers: 11 # 11
|
25 |
+
latent_space: false
|
26 |
+
num_frames: 16
|
27 |
+
resolution: 256
|
28 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
|
29 |
+
train_data_dir: ""
|
30 |
+
val_fileslist: ""
|
31 |
+
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
32 |
+
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
33 |
+
lower_half: true
|
34 |
+
audio_sample_rate: 16000
|
35 |
+
video_fps: 25
|
36 |
+
|
37 |
+
optimizer:
|
38 |
+
lr: 1e-5
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
run:
|
42 |
+
max_train_steps: 10000000
|
43 |
+
validation_steps: 2500
|
44 |
+
mixed_precision_training: true
|
45 |
+
seed: 42
|
configs/syncnet/syncnet_25_pixel.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
audio_encoder: # input (1, 80, 80)
|
3 |
+
in_channels: 1
|
4 |
+
block_out_channels: [64, 128, 256, 256, 512, 1024]
|
5 |
+
downsample_factors: [2, 2, 2, 2, 2, 2]
|
6 |
+
dropout: 0.0
|
7 |
+
visual_encoder: # input (75, 128, 256)
|
8 |
+
in_channels: 75
|
9 |
+
block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024]
|
10 |
+
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
11 |
+
dropout: 0.0
|
12 |
+
|
13 |
+
ckpt:
|
14 |
+
resume_ckpt_path: ""
|
15 |
+
inference_ckpt_path: ""
|
16 |
+
save_ckpt_steps: 2500
|
17 |
+
|
18 |
+
data:
|
19 |
+
train_output_dir: debug/syncnet
|
20 |
+
num_val_samples: 2048
|
21 |
+
batch_size: 64 # 64
|
22 |
+
num_workers: 11 # 11
|
23 |
+
latent_space: false
|
24 |
+
num_frames: 25
|
25 |
+
resolution: 256
|
26 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/hdtf_vox_avatars_ads_affine.txt
|
27 |
+
# /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/hdtf_voxceleb_avatars_affine.txt
|
28 |
+
train_data_dir: ""
|
29 |
+
val_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/vox_affine_val.txt
|
30 |
+
# /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/voxceleb_val.txt
|
31 |
+
val_data_dir: ""
|
32 |
+
audio_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
33 |
+
lower_half: true
|
34 |
+
pretrained_audio_model_path: facebook/wav2vec2-large-xlsr-53
|
35 |
+
audio_sample_rate: 16000
|
36 |
+
video_fps: 25
|
37 |
+
|
38 |
+
optimizer:
|
39 |
+
lr: 1e-5
|
40 |
+
max_grad_norm: 1.0
|
41 |
+
|
42 |
+
run:
|
43 |
+
max_train_steps: 10000000
|
44 |
+
mixed_precision_training: true
|
45 |
+
seed: 42
|
configs/unet/first_stage.yaml
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
syncnet_config_path: configs/syncnet/syncnet_16_pixel.yaml
|
3 |
+
train_output_dir: debug/unet
|
4 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
|
5 |
+
train_data_dir: ""
|
6 |
+
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/whisper_new
|
7 |
+
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
8 |
+
|
9 |
+
val_video_path: assets/demo1_video.mp4
|
10 |
+
val_audio_path: assets/demo1_audio.wav
|
11 |
+
batch_size: 8 # 8
|
12 |
+
num_workers: 11 # 11
|
13 |
+
num_frames: 16
|
14 |
+
resolution: 256
|
15 |
+
mask: fix_mask
|
16 |
+
audio_sample_rate: 16000
|
17 |
+
video_fps: 25
|
18 |
+
|
19 |
+
ckpt:
|
20 |
+
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
21 |
+
save_ckpt_steps: 5000
|
22 |
+
|
23 |
+
run:
|
24 |
+
pixel_space_supervise: false
|
25 |
+
use_syncnet: false
|
26 |
+
sync_loss_weight: 0.05 # 1/283
|
27 |
+
perceptual_loss_weight: 0.1 # 0.1
|
28 |
+
recon_loss_weight: 1 # 1
|
29 |
+
guidance_scale: 1.0 # 1.5 or 1.0
|
30 |
+
trepa_loss_weight: 10
|
31 |
+
inference_steps: 20
|
32 |
+
seed: 1247
|
33 |
+
use_mixed_noise: true
|
34 |
+
mixed_noise_alpha: 1 # 1
|
35 |
+
mixed_precision_training: true
|
36 |
+
enable_gradient_checkpointing: false
|
37 |
+
enable_xformers_memory_efficient_attention: true
|
38 |
+
max_train_steps: 10000000
|
39 |
+
max_train_epochs: -1
|
40 |
+
|
41 |
+
optimizer:
|
42 |
+
lr: 1e-5
|
43 |
+
scale_lr: false
|
44 |
+
max_grad_norm: 1.0
|
45 |
+
lr_scheduler: constant
|
46 |
+
lr_warmup_steps: 0
|
47 |
+
|
48 |
+
model:
|
49 |
+
act_fn: silu
|
50 |
+
add_audio_layer: true
|
51 |
+
custom_audio_layer: false
|
52 |
+
audio_condition_method: cross_attn # Choose between [cross_attn, group_norm]
|
53 |
+
attention_head_dim: 8
|
54 |
+
block_out_channels: [320, 640, 1280, 1280]
|
55 |
+
center_input_sample: false
|
56 |
+
cross_attention_dim: 384
|
57 |
+
down_block_types:
|
58 |
+
[
|
59 |
+
"CrossAttnDownBlock3D",
|
60 |
+
"CrossAttnDownBlock3D",
|
61 |
+
"CrossAttnDownBlock3D",
|
62 |
+
"DownBlock3D",
|
63 |
+
]
|
64 |
+
mid_block_type: UNetMidBlock3DCrossAttn
|
65 |
+
up_block_types:
|
66 |
+
[
|
67 |
+
"UpBlock3D",
|
68 |
+
"CrossAttnUpBlock3D",
|
69 |
+
"CrossAttnUpBlock3D",
|
70 |
+
"CrossAttnUpBlock3D",
|
71 |
+
]
|
72 |
+
downsample_padding: 1
|
73 |
+
flip_sin_to_cos: true
|
74 |
+
freq_shift: 0
|
75 |
+
in_channels: 13 # 49
|
76 |
+
layers_per_block: 2
|
77 |
+
mid_block_scale_factor: 1
|
78 |
+
norm_eps: 1e-5
|
79 |
+
norm_num_groups: 32
|
80 |
+
out_channels: 4 # 16
|
81 |
+
sample_size: 64
|
82 |
+
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
83 |
+
unet_use_cross_frame_attention: false
|
84 |
+
unet_use_temporal_attention: false
|
85 |
+
|
86 |
+
# Actually we don't use the motion module in the final version of LatentSync
|
87 |
+
# When we started the project, we used the codebase of AnimateDiff and tried motion module, the results are poor
|
88 |
+
# We decied to leave the code here for possible future usage
|
89 |
+
use_motion_module: false
|
90 |
+
motion_module_resolutions: [1, 2, 4, 8]
|
91 |
+
motion_module_mid_block: false
|
92 |
+
motion_module_decoder_only: false
|
93 |
+
motion_module_type: Vanilla
|
94 |
+
motion_module_kwargs:
|
95 |
+
num_attention_heads: 8
|
96 |
+
num_transformer_block: 1
|
97 |
+
attention_block_types:
|
98 |
+
- Temporal_Self
|
99 |
+
- Temporal_Self
|
100 |
+
temporal_position_encoding: true
|
101 |
+
temporal_position_encoding_max_len: 16
|
102 |
+
temporal_attention_dim_div: 1
|
103 |
+
zero_initialize: true
|
configs/unet/second_stage.yaml
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
syncnet_config_path: configs/syncnet/syncnet_16_pixel.yaml
|
3 |
+
train_output_dir: debug/unet
|
4 |
+
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
|
5 |
+
train_data_dir: ""
|
6 |
+
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/whisper_new
|
7 |
+
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel_new
|
8 |
+
|
9 |
+
val_video_path: assets/demo1_video.mp4
|
10 |
+
val_audio_path: assets/demo1_audio.wav
|
11 |
+
batch_size: 2 # 8
|
12 |
+
num_workers: 11 # 11
|
13 |
+
num_frames: 16
|
14 |
+
resolution: 256
|
15 |
+
mask: fix_mask
|
16 |
+
audio_sample_rate: 16000
|
17 |
+
video_fps: 25
|
18 |
+
|
19 |
+
ckpt:
|
20 |
+
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
21 |
+
save_ckpt_steps: 5000
|
22 |
+
|
23 |
+
run:
|
24 |
+
pixel_space_supervise: true
|
25 |
+
use_syncnet: true
|
26 |
+
sync_loss_weight: 0.05 # 1/283
|
27 |
+
perceptual_loss_weight: 0.1 # 0.1
|
28 |
+
recon_loss_weight: 1 # 1
|
29 |
+
guidance_scale: 1.0 # 1.5 or 1.0
|
30 |
+
trepa_loss_weight: 10
|
31 |
+
inference_steps: 20
|
32 |
+
seed: 1247
|
33 |
+
use_mixed_noise: true
|
34 |
+
mixed_noise_alpha: 1 # 1
|
35 |
+
mixed_precision_training: true
|
36 |
+
enable_gradient_checkpointing: false
|
37 |
+
enable_xformers_memory_efficient_attention: true
|
38 |
+
max_train_steps: 10000000
|
39 |
+
max_train_epochs: -1
|
40 |
+
|
41 |
+
optimizer:
|
42 |
+
lr: 1e-5
|
43 |
+
scale_lr: false
|
44 |
+
max_grad_norm: 1.0
|
45 |
+
lr_scheduler: constant
|
46 |
+
lr_warmup_steps: 0
|
47 |
+
|
48 |
+
model:
|
49 |
+
act_fn: silu
|
50 |
+
add_audio_layer: true
|
51 |
+
custom_audio_layer: false
|
52 |
+
audio_condition_method: cross_attn # Choose between [cross_attn, group_norm]
|
53 |
+
attention_head_dim: 8
|
54 |
+
block_out_channels: [320, 640, 1280, 1280]
|
55 |
+
center_input_sample: false
|
56 |
+
cross_attention_dim: 384
|
57 |
+
down_block_types:
|
58 |
+
[
|
59 |
+
"CrossAttnDownBlock3D",
|
60 |
+
"CrossAttnDownBlock3D",
|
61 |
+
"CrossAttnDownBlock3D",
|
62 |
+
"DownBlock3D",
|
63 |
+
]
|
64 |
+
mid_block_type: UNetMidBlock3DCrossAttn
|
65 |
+
up_block_types:
|
66 |
+
[
|
67 |
+
"UpBlock3D",
|
68 |
+
"CrossAttnUpBlock3D",
|
69 |
+
"CrossAttnUpBlock3D",
|
70 |
+
"CrossAttnUpBlock3D",
|
71 |
+
]
|
72 |
+
downsample_padding: 1
|
73 |
+
flip_sin_to_cos: true
|
74 |
+
freq_shift: 0
|
75 |
+
in_channels: 13 # 49
|
76 |
+
layers_per_block: 2
|
77 |
+
mid_block_scale_factor: 1
|
78 |
+
norm_eps: 1e-5
|
79 |
+
norm_num_groups: 32
|
80 |
+
out_channels: 4 # 16
|
81 |
+
sample_size: 64
|
82 |
+
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
83 |
+
unet_use_cross_frame_attention: false
|
84 |
+
unet_use_temporal_attention: false
|
85 |
+
|
86 |
+
# Actually we don't use the motion module in the final version of LatentSync
|
87 |
+
# When we started the project, we used the codebase of AnimateDiff and tried motion module, the results are poor
|
88 |
+
# We decied to leave the code here for possible future usage
|
89 |
+
use_motion_module: false
|
90 |
+
motion_module_resolutions: [1, 2, 4, 8]
|
91 |
+
motion_module_mid_block: false
|
92 |
+
motion_module_decoder_only: false
|
93 |
+
motion_module_type: Vanilla
|
94 |
+
motion_module_kwargs:
|
95 |
+
num_attention_heads: 8
|
96 |
+
num_transformer_block: 1
|
97 |
+
attention_block_types:
|
98 |
+
- Temporal_Self
|
99 |
+
- Temporal_Self
|
100 |
+
temporal_position_encoding: true
|
101 |
+
temporal_position_encoding_max_len: 16
|
102 |
+
temporal_attention_dim_div: 1
|
103 |
+
zero_initialize: true
|
data_processing_pipeline.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
python -m preprocess.data_processing_pipeline \
|
4 |
+
--total_num_workers 20 \
|
5 |
+
--per_gpu_num_workers 20 \
|
6 |
+
--resolution 256 \
|
7 |
+
--sync_conf_threshold 3 \
|
8 |
+
--temp_dir temp \
|
9 |
+
--input_dir /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw
|
eval/detectors/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Face detector
|
2 |
+
|
3 |
+
This face detector is adapted from `https://github.com/cs-giung/face-detection-pytorch`.
|
eval/detectors/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .s3fd import S3FD
|
eval/detectors/s3fd/__init__.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
from .nets import S3FDNet
|
7 |
+
from .box_utils import nms_
|
8 |
+
|
9 |
+
PATH_WEIGHT = 'checkpoints/auxiliary/sfd_face.pth'
|
10 |
+
img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32')
|
11 |
+
|
12 |
+
|
13 |
+
class S3FD():
|
14 |
+
|
15 |
+
def __init__(self, device='cuda'):
|
16 |
+
|
17 |
+
tstamp = time.time()
|
18 |
+
self.device = device
|
19 |
+
|
20 |
+
print('[S3FD] loading with', self.device)
|
21 |
+
self.net = S3FDNet(device=self.device).to(self.device)
|
22 |
+
state_dict = torch.load(PATH_WEIGHT, map_location=self.device)
|
23 |
+
self.net.load_state_dict(state_dict)
|
24 |
+
self.net.eval()
|
25 |
+
print('[S3FD] finished loading (%.4f sec)' % (time.time() - tstamp))
|
26 |
+
|
27 |
+
def detect_faces(self, image, conf_th=0.8, scales=[1]):
|
28 |
+
|
29 |
+
w, h = image.shape[1], image.shape[0]
|
30 |
+
|
31 |
+
bboxes = np.empty(shape=(0, 5))
|
32 |
+
|
33 |
+
with torch.no_grad():
|
34 |
+
for s in scales:
|
35 |
+
scaled_img = cv2.resize(image, dsize=(0, 0), fx=s, fy=s, interpolation=cv2.INTER_LINEAR)
|
36 |
+
|
37 |
+
scaled_img = np.swapaxes(scaled_img, 1, 2)
|
38 |
+
scaled_img = np.swapaxes(scaled_img, 1, 0)
|
39 |
+
scaled_img = scaled_img[[2, 1, 0], :, :]
|
40 |
+
scaled_img = scaled_img.astype('float32')
|
41 |
+
scaled_img -= img_mean
|
42 |
+
scaled_img = scaled_img[[2, 1, 0], :, :]
|
43 |
+
x = torch.from_numpy(scaled_img).unsqueeze(0).to(self.device)
|
44 |
+
y = self.net(x)
|
45 |
+
|
46 |
+
detections = y.data
|
47 |
+
scale = torch.Tensor([w, h, w, h])
|
48 |
+
|
49 |
+
for i in range(detections.size(1)):
|
50 |
+
j = 0
|
51 |
+
while detections[0, i, j, 0] > conf_th:
|
52 |
+
score = detections[0, i, j, 0]
|
53 |
+
pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
|
54 |
+
bbox = (pt[0], pt[1], pt[2], pt[3], score)
|
55 |
+
bboxes = np.vstack((bboxes, bbox))
|
56 |
+
j += 1
|
57 |
+
|
58 |
+
keep = nms_(bboxes, 0.1)
|
59 |
+
bboxes = bboxes[keep]
|
60 |
+
|
61 |
+
return bboxes
|
eval/detectors/s3fd/box_utils.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from itertools import product as product
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Function
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
|
8 |
+
def nms_(dets, thresh):
|
9 |
+
"""
|
10 |
+
Courtesy of Ross Girshick
|
11 |
+
[https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py]
|
12 |
+
"""
|
13 |
+
x1 = dets[:, 0]
|
14 |
+
y1 = dets[:, 1]
|
15 |
+
x2 = dets[:, 2]
|
16 |
+
y2 = dets[:, 3]
|
17 |
+
scores = dets[:, 4]
|
18 |
+
|
19 |
+
areas = (x2 - x1) * (y2 - y1)
|
20 |
+
order = scores.argsort()[::-1]
|
21 |
+
|
22 |
+
keep = []
|
23 |
+
while order.size > 0:
|
24 |
+
i = order[0]
|
25 |
+
keep.append(int(i))
|
26 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
27 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
28 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
29 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
30 |
+
|
31 |
+
w = np.maximum(0.0, xx2 - xx1)
|
32 |
+
h = np.maximum(0.0, yy2 - yy1)
|
33 |
+
inter = w * h
|
34 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
35 |
+
|
36 |
+
inds = np.where(ovr <= thresh)[0]
|
37 |
+
order = order[inds + 1]
|
38 |
+
|
39 |
+
return np.array(keep).astype(np.int32)
|
40 |
+
|
41 |
+
|
42 |
+
def decode(loc, priors, variances):
|
43 |
+
"""Decode locations from predictions using priors to undo
|
44 |
+
the encoding we did for offset regression at train time.
|
45 |
+
Args:
|
46 |
+
loc (tensor): location predictions for loc layers,
|
47 |
+
Shape: [num_priors,4]
|
48 |
+
priors (tensor): Prior boxes in center-offset form.
|
49 |
+
Shape: [num_priors,4].
|
50 |
+
variances: (list[float]) Variances of priorboxes
|
51 |
+
Return:
|
52 |
+
decoded bounding box predictions
|
53 |
+
"""
|
54 |
+
|
55 |
+
boxes = torch.cat((
|
56 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
57 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
58 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
59 |
+
boxes[:, 2:] += boxes[:, :2]
|
60 |
+
return boxes
|
61 |
+
|
62 |
+
|
63 |
+
def nms(boxes, scores, overlap=0.5, top_k=200):
|
64 |
+
"""Apply non-maximum suppression at test time to avoid detecting too many
|
65 |
+
overlapping bounding boxes for a given object.
|
66 |
+
Args:
|
67 |
+
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
|
68 |
+
scores: (tensor) The class predscores for the img, Shape:[num_priors].
|
69 |
+
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
|
70 |
+
top_k: (int) The Maximum number of box preds to consider.
|
71 |
+
Return:
|
72 |
+
The indices of the kept boxes with respect to num_priors.
|
73 |
+
"""
|
74 |
+
|
75 |
+
keep = scores.new(scores.size(0)).zero_().long()
|
76 |
+
if boxes.numel() == 0:
|
77 |
+
return keep, 0
|
78 |
+
x1 = boxes[:, 0]
|
79 |
+
y1 = boxes[:, 1]
|
80 |
+
x2 = boxes[:, 2]
|
81 |
+
y2 = boxes[:, 3]
|
82 |
+
area = torch.mul(x2 - x1, y2 - y1)
|
83 |
+
v, idx = scores.sort(0) # sort in ascending order
|
84 |
+
# I = I[v >= 0.01]
|
85 |
+
idx = idx[-top_k:] # indices of the top-k largest vals
|
86 |
+
xx1 = boxes.new()
|
87 |
+
yy1 = boxes.new()
|
88 |
+
xx2 = boxes.new()
|
89 |
+
yy2 = boxes.new()
|
90 |
+
w = boxes.new()
|
91 |
+
h = boxes.new()
|
92 |
+
|
93 |
+
# keep = torch.Tensor()
|
94 |
+
count = 0
|
95 |
+
while idx.numel() > 0:
|
96 |
+
i = idx[-1] # index of current largest val
|
97 |
+
# keep.append(i)
|
98 |
+
keep[count] = i
|
99 |
+
count += 1
|
100 |
+
if idx.size(0) == 1:
|
101 |
+
break
|
102 |
+
idx = idx[:-1] # remove kept element from view
|
103 |
+
# load bboxes of next highest vals
|
104 |
+
with warnings.catch_warnings():
|
105 |
+
# Ignore UserWarning within this block
|
106 |
+
warnings.simplefilter("ignore", category=UserWarning)
|
107 |
+
torch.index_select(x1, 0, idx, out=xx1)
|
108 |
+
torch.index_select(y1, 0, idx, out=yy1)
|
109 |
+
torch.index_select(x2, 0, idx, out=xx2)
|
110 |
+
torch.index_select(y2, 0, idx, out=yy2)
|
111 |
+
# store element-wise max with next highest score
|
112 |
+
xx1 = torch.clamp(xx1, min=x1[i])
|
113 |
+
yy1 = torch.clamp(yy1, min=y1[i])
|
114 |
+
xx2 = torch.clamp(xx2, max=x2[i])
|
115 |
+
yy2 = torch.clamp(yy2, max=y2[i])
|
116 |
+
w.resize_as_(xx2)
|
117 |
+
h.resize_as_(yy2)
|
118 |
+
w = xx2 - xx1
|
119 |
+
h = yy2 - yy1
|
120 |
+
# check sizes of xx1 and xx2.. after each iteration
|
121 |
+
w = torch.clamp(w, min=0.0)
|
122 |
+
h = torch.clamp(h, min=0.0)
|
123 |
+
inter = w * h
|
124 |
+
# IoU = i / (area(a) + area(b) - i)
|
125 |
+
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
|
126 |
+
union = (rem_areas - inter) + area[i]
|
127 |
+
IoU = inter / union # store result in iou
|
128 |
+
# keep only elements with an IoU <= overlap
|
129 |
+
idx = idx[IoU.le(overlap)]
|
130 |
+
return keep, count
|
131 |
+
|
132 |
+
|
133 |
+
class Detect(object):
|
134 |
+
|
135 |
+
def __init__(self, num_classes=2,
|
136 |
+
top_k=750, nms_thresh=0.3, conf_thresh=0.05,
|
137 |
+
variance=[0.1, 0.2], nms_top_k=5000):
|
138 |
+
|
139 |
+
self.num_classes = num_classes
|
140 |
+
self.top_k = top_k
|
141 |
+
self.nms_thresh = nms_thresh
|
142 |
+
self.conf_thresh = conf_thresh
|
143 |
+
self.variance = variance
|
144 |
+
self.nms_top_k = nms_top_k
|
145 |
+
|
146 |
+
def forward(self, loc_data, conf_data, prior_data):
|
147 |
+
|
148 |
+
num = loc_data.size(0)
|
149 |
+
num_priors = prior_data.size(0)
|
150 |
+
|
151 |
+
conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)
|
152 |
+
batch_priors = prior_data.view(-1, num_priors, 4).expand(num, num_priors, 4)
|
153 |
+
batch_priors = batch_priors.contiguous().view(-1, 4)
|
154 |
+
|
155 |
+
decoded_boxes = decode(loc_data.view(-1, 4), batch_priors, self.variance)
|
156 |
+
decoded_boxes = decoded_boxes.view(num, num_priors, 4)
|
157 |
+
|
158 |
+
output = torch.zeros(num, self.num_classes, self.top_k, 5)
|
159 |
+
|
160 |
+
for i in range(num):
|
161 |
+
boxes = decoded_boxes[i].clone()
|
162 |
+
conf_scores = conf_preds[i].clone()
|
163 |
+
|
164 |
+
for cl in range(1, self.num_classes):
|
165 |
+
c_mask = conf_scores[cl].gt(self.conf_thresh)
|
166 |
+
scores = conf_scores[cl][c_mask]
|
167 |
+
|
168 |
+
if scores.dim() == 0:
|
169 |
+
continue
|
170 |
+
l_mask = c_mask.unsqueeze(1).expand_as(boxes)
|
171 |
+
boxes_ = boxes[l_mask].view(-1, 4)
|
172 |
+
ids, count = nms(boxes_, scores, self.nms_thresh, self.nms_top_k)
|
173 |
+
count = count if count < self.top_k else self.top_k
|
174 |
+
|
175 |
+
output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes_[ids[:count]]), 1)
|
176 |
+
|
177 |
+
return output
|
178 |
+
|
179 |
+
|
180 |
+
class PriorBox(object):
|
181 |
+
|
182 |
+
def __init__(self, input_size, feature_maps,
|
183 |
+
variance=[0.1, 0.2],
|
184 |
+
min_sizes=[16, 32, 64, 128, 256, 512],
|
185 |
+
steps=[4, 8, 16, 32, 64, 128],
|
186 |
+
clip=False):
|
187 |
+
|
188 |
+
super(PriorBox, self).__init__()
|
189 |
+
|
190 |
+
self.imh = input_size[0]
|
191 |
+
self.imw = input_size[1]
|
192 |
+
self.feature_maps = feature_maps
|
193 |
+
|
194 |
+
self.variance = variance
|
195 |
+
self.min_sizes = min_sizes
|
196 |
+
self.steps = steps
|
197 |
+
self.clip = clip
|
198 |
+
|
199 |
+
def forward(self):
|
200 |
+
mean = []
|
201 |
+
for k, fmap in enumerate(self.feature_maps):
|
202 |
+
feath = fmap[0]
|
203 |
+
featw = fmap[1]
|
204 |
+
for i, j in product(range(feath), range(featw)):
|
205 |
+
f_kw = self.imw / self.steps[k]
|
206 |
+
f_kh = self.imh / self.steps[k]
|
207 |
+
|
208 |
+
cx = (j + 0.5) / f_kw
|
209 |
+
cy = (i + 0.5) / f_kh
|
210 |
+
|
211 |
+
s_kw = self.min_sizes[k] / self.imw
|
212 |
+
s_kh = self.min_sizes[k] / self.imh
|
213 |
+
|
214 |
+
mean += [cx, cy, s_kw, s_kh]
|
215 |
+
|
216 |
+
output = torch.FloatTensor(mean).view(-1, 4)
|
217 |
+
|
218 |
+
if self.clip:
|
219 |
+
output.clamp_(max=1, min=0)
|
220 |
+
|
221 |
+
return output
|
eval/detectors/s3fd/nets.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.nn.init as init
|
5 |
+
from .box_utils import Detect, PriorBox
|
6 |
+
|
7 |
+
|
8 |
+
class L2Norm(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, n_channels, scale):
|
11 |
+
super(L2Norm, self).__init__()
|
12 |
+
self.n_channels = n_channels
|
13 |
+
self.gamma = scale or None
|
14 |
+
self.eps = 1e-10
|
15 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
16 |
+
self.reset_parameters()
|
17 |
+
|
18 |
+
def reset_parameters(self):
|
19 |
+
init.constant_(self.weight, self.gamma)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
23 |
+
x = torch.div(x, norm)
|
24 |
+
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
|
25 |
+
return out
|
26 |
+
|
27 |
+
|
28 |
+
class S3FDNet(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self, device='cuda'):
|
31 |
+
super(S3FDNet, self).__init__()
|
32 |
+
self.device = device
|
33 |
+
|
34 |
+
self.vgg = nn.ModuleList([
|
35 |
+
nn.Conv2d(3, 64, 3, 1, padding=1),
|
36 |
+
nn.ReLU(inplace=True),
|
37 |
+
nn.Conv2d(64, 64, 3, 1, padding=1),
|
38 |
+
nn.ReLU(inplace=True),
|
39 |
+
nn.MaxPool2d(2, 2),
|
40 |
+
|
41 |
+
nn.Conv2d(64, 128, 3, 1, padding=1),
|
42 |
+
nn.ReLU(inplace=True),
|
43 |
+
nn.Conv2d(128, 128, 3, 1, padding=1),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
nn.MaxPool2d(2, 2),
|
46 |
+
|
47 |
+
nn.Conv2d(128, 256, 3, 1, padding=1),
|
48 |
+
nn.ReLU(inplace=True),
|
49 |
+
nn.Conv2d(256, 256, 3, 1, padding=1),
|
50 |
+
nn.ReLU(inplace=True),
|
51 |
+
nn.Conv2d(256, 256, 3, 1, padding=1),
|
52 |
+
nn.ReLU(inplace=True),
|
53 |
+
nn.MaxPool2d(2, 2, ceil_mode=True),
|
54 |
+
|
55 |
+
nn.Conv2d(256, 512, 3, 1, padding=1),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
58 |
+
nn.ReLU(inplace=True),
|
59 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
60 |
+
nn.ReLU(inplace=True),
|
61 |
+
nn.MaxPool2d(2, 2),
|
62 |
+
|
63 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
64 |
+
nn.ReLU(inplace=True),
|
65 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
66 |
+
nn.ReLU(inplace=True),
|
67 |
+
nn.Conv2d(512, 512, 3, 1, padding=1),
|
68 |
+
nn.ReLU(inplace=True),
|
69 |
+
nn.MaxPool2d(2, 2),
|
70 |
+
|
71 |
+
nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6),
|
72 |
+
nn.ReLU(inplace=True),
|
73 |
+
nn.Conv2d(1024, 1024, 1, 1),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
])
|
76 |
+
|
77 |
+
self.L2Norm3_3 = L2Norm(256, 10)
|
78 |
+
self.L2Norm4_3 = L2Norm(512, 8)
|
79 |
+
self.L2Norm5_3 = L2Norm(512, 5)
|
80 |
+
|
81 |
+
self.extras = nn.ModuleList([
|
82 |
+
nn.Conv2d(1024, 256, 1, 1),
|
83 |
+
nn.Conv2d(256, 512, 3, 2, padding=1),
|
84 |
+
nn.Conv2d(512, 128, 1, 1),
|
85 |
+
nn.Conv2d(128, 256, 3, 2, padding=1),
|
86 |
+
])
|
87 |
+
|
88 |
+
self.loc = nn.ModuleList([
|
89 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
90 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
91 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
92 |
+
nn.Conv2d(1024, 4, 3, 1, padding=1),
|
93 |
+
nn.Conv2d(512, 4, 3, 1, padding=1),
|
94 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
95 |
+
])
|
96 |
+
|
97 |
+
self.conf = nn.ModuleList([
|
98 |
+
nn.Conv2d(256, 4, 3, 1, padding=1),
|
99 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
100 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
101 |
+
nn.Conv2d(1024, 2, 3, 1, padding=1),
|
102 |
+
nn.Conv2d(512, 2, 3, 1, padding=1),
|
103 |
+
nn.Conv2d(256, 2, 3, 1, padding=1),
|
104 |
+
])
|
105 |
+
|
106 |
+
self.softmax = nn.Softmax(dim=-1)
|
107 |
+
self.detect = Detect()
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
size = x.size()[2:]
|
111 |
+
sources = list()
|
112 |
+
loc = list()
|
113 |
+
conf = list()
|
114 |
+
|
115 |
+
for k in range(16):
|
116 |
+
x = self.vgg[k](x)
|
117 |
+
s = self.L2Norm3_3(x)
|
118 |
+
sources.append(s)
|
119 |
+
|
120 |
+
for k in range(16, 23):
|
121 |
+
x = self.vgg[k](x)
|
122 |
+
s = self.L2Norm4_3(x)
|
123 |
+
sources.append(s)
|
124 |
+
|
125 |
+
for k in range(23, 30):
|
126 |
+
x = self.vgg[k](x)
|
127 |
+
s = self.L2Norm5_3(x)
|
128 |
+
sources.append(s)
|
129 |
+
|
130 |
+
for k in range(30, len(self.vgg)):
|
131 |
+
x = self.vgg[k](x)
|
132 |
+
sources.append(x)
|
133 |
+
|
134 |
+
# apply extra layers and cache source layer outputs
|
135 |
+
for k, v in enumerate(self.extras):
|
136 |
+
x = F.relu(v(x), inplace=True)
|
137 |
+
if k % 2 == 1:
|
138 |
+
sources.append(x)
|
139 |
+
|
140 |
+
# apply multibox head to source layers
|
141 |
+
loc_x = self.loc[0](sources[0])
|
142 |
+
conf_x = self.conf[0](sources[0])
|
143 |
+
|
144 |
+
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
|
145 |
+
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
|
146 |
+
|
147 |
+
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
|
148 |
+
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
|
149 |
+
|
150 |
+
for i in range(1, len(sources)):
|
151 |
+
x = sources[i]
|
152 |
+
conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous())
|
153 |
+
loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous())
|
154 |
+
|
155 |
+
features_maps = []
|
156 |
+
for i in range(len(loc)):
|
157 |
+
feat = []
|
158 |
+
feat += [loc[i].size(1), loc[i].size(2)]
|
159 |
+
features_maps += [feat]
|
160 |
+
|
161 |
+
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
|
162 |
+
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
|
163 |
+
|
164 |
+
with torch.no_grad():
|
165 |
+
self.priorbox = PriorBox(size, features_maps)
|
166 |
+
self.priors = self.priorbox.forward()
|
167 |
+
|
168 |
+
output = self.detect.forward(
|
169 |
+
loc.view(loc.size(0), -1, 4),
|
170 |
+
self.softmax(conf.view(conf.size(0), -1, 2)),
|
171 |
+
self.priors.type(type(x.data)).to(self.device)
|
172 |
+
)
|
173 |
+
|
174 |
+
return output
|
eval/draw_syncnet_lines.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
|
19 |
+
class Chart:
|
20 |
+
def __init__(self):
|
21 |
+
self.loss_list = []
|
22 |
+
|
23 |
+
def add_ckpt(self, ckpt_path, line_name):
|
24 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
25 |
+
train_step_list = ckpt["train_step_list"]
|
26 |
+
train_loss_list = ckpt["train_loss_list"]
|
27 |
+
val_step_list = ckpt["val_step_list"]
|
28 |
+
val_loss_list = ckpt["val_loss_list"]
|
29 |
+
val_step_list = [val_step_list[0]] + val_step_list[4::5]
|
30 |
+
val_loss_list = [val_loss_list[0]] + val_loss_list[4::5]
|
31 |
+
self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list))
|
32 |
+
|
33 |
+
def draw(self, save_path, plot_val=True):
|
34 |
+
# Global settings
|
35 |
+
plt.rcParams["font.size"] = 14
|
36 |
+
plt.rcParams["font.family"] = "serif"
|
37 |
+
plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans", "Lucida Grande"]
|
38 |
+
plt.rcParams["font.serif"] = ["Times New Roman", "DejaVu Serif"]
|
39 |
+
|
40 |
+
# Creating the plot
|
41 |
+
plt.figure(figsize=(7.766, 4.8)) # Golden ratio
|
42 |
+
for loss in self.loss_list:
|
43 |
+
if plot_val:
|
44 |
+
(line,) = plt.plot(loss[1], loss[2], label=loss[0], linewidth=0.5, alpha=0.5)
|
45 |
+
line_color = line.get_color()
|
46 |
+
plt.plot(loss[3], loss[4], linewidth=1.5, color=line_color)
|
47 |
+
else:
|
48 |
+
plt.plot(loss[1], loss[2], label=loss[0], linewidth=1)
|
49 |
+
plt.xlabel("Step")
|
50 |
+
plt.ylabel("Loss")
|
51 |
+
legend = plt.legend()
|
52 |
+
# legend = plt.legend(loc='upper right', bbox_to_anchor=(1, 0.82))
|
53 |
+
|
54 |
+
# Adjust the linewidth of legend
|
55 |
+
for line in legend.get_lines():
|
56 |
+
line.set_linewidth(2)
|
57 |
+
|
58 |
+
plt.savefig(save_path, transparent=True)
|
59 |
+
plt.close()
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
chart = Chart()
|
64 |
+
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:14:43/checkpoints/checkpoint-10000.pt", "w/ self-attn")
|
65 |
+
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "w/o self-attn")
|
66 |
+
chart.add_ckpt("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512")
|
67 |
+
chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048")
|
68 |
+
chart.add_ckpt("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096")
|
69 |
+
chart.add_ckpt("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144")
|
70 |
+
chart.draw("ablation.pdf", plot_val=True)
|
eval/eval_fvd.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import mediapipe as mp
|
16 |
+
import cv2
|
17 |
+
from decord import VideoReader
|
18 |
+
from einops import rearrange
|
19 |
+
import os
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import tqdm
|
23 |
+
from eval.fvd import compute_our_fvd
|
24 |
+
|
25 |
+
|
26 |
+
class FVD:
|
27 |
+
def __init__(self, resolution=(224, 224)):
|
28 |
+
self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
|
29 |
+
self.resolution = resolution
|
30 |
+
|
31 |
+
def detect_face(self, image):
|
32 |
+
height, width = image.shape[:2]
|
33 |
+
# Process the image and detect faces.
|
34 |
+
results = self.face_detector.process(image)
|
35 |
+
|
36 |
+
if not results.detections: # Face not detected
|
37 |
+
raise Exception("Face not detected")
|
38 |
+
|
39 |
+
detection = results.detections[0] # Only use the first face in the image
|
40 |
+
bounding_box = detection.location_data.relative_bounding_box
|
41 |
+
xmin = int(bounding_box.xmin * width)
|
42 |
+
ymin = int(bounding_box.ymin * height)
|
43 |
+
face_width = int(bounding_box.width * width)
|
44 |
+
face_height = int(bounding_box.height * height)
|
45 |
+
|
46 |
+
# Crop the image to the bounding box.
|
47 |
+
xmin = max(0, xmin)
|
48 |
+
ymin = max(0, ymin)
|
49 |
+
xmax = min(width, xmin + face_width)
|
50 |
+
ymax = min(height, ymin + face_height)
|
51 |
+
image = image[ymin:ymax, xmin:xmax]
|
52 |
+
|
53 |
+
return image
|
54 |
+
|
55 |
+
def detect_video(self, video_path, real: bool = True):
|
56 |
+
vr = VideoReader(video_path)
|
57 |
+
video_frames = vr[20:36].asnumpy() # Use one frame per second
|
58 |
+
vr.seek(0) # avoid memory leak
|
59 |
+
faces = []
|
60 |
+
for frame in video_frames:
|
61 |
+
face = self.detect_face(frame)
|
62 |
+
face = cv2.resize(face, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_AREA)
|
63 |
+
faces.append(face)
|
64 |
+
|
65 |
+
if len(faces) != 16:
|
66 |
+
return None
|
67 |
+
faces = np.stack(faces, axis=0) # (f, h, w, c)
|
68 |
+
faces = torch.from_numpy(faces)
|
69 |
+
return faces
|
70 |
+
|
71 |
+
|
72 |
+
def eval_fvd(real_videos_dir, fake_videos_dir):
|
73 |
+
fvd = FVD()
|
74 |
+
real_features_list = []
|
75 |
+
fake_features_list = []
|
76 |
+
for file in tqdm.tqdm(os.listdir(fake_videos_dir)):
|
77 |
+
if file.endswith(".mp4"):
|
78 |
+
real_video_path = os.path.join(real_videos_dir, file.replace("_out.mp4", ".mp4"))
|
79 |
+
fake_video_path = os.path.join(fake_videos_dir, file)
|
80 |
+
real_features = fvd.detect_video(real_video_path, real=True)
|
81 |
+
fake_features = fvd.detect_video(fake_video_path, real=False)
|
82 |
+
if real_features is None or fake_features is None:
|
83 |
+
continue
|
84 |
+
real_features_list.append(real_features)
|
85 |
+
fake_features_list.append(fake_features)
|
86 |
+
|
87 |
+
real_features = torch.stack(real_features_list) / 255.0
|
88 |
+
fake_features = torch.stack(fake_features_list) / 255.0
|
89 |
+
print(compute_our_fvd(real_features, fake_features, device="cpu"))
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
real_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/cross"
|
94 |
+
fake_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/latentsync_cross"
|
95 |
+
|
96 |
+
eval_fvd(real_videos_dir, fake_videos_dir)
|
eval/eval_sync_conf.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
import tqdm
|
18 |
+
from statistics import fmean
|
19 |
+
from eval.syncnet import SyncNetEval
|
20 |
+
from eval.syncnet_detect import SyncNetDetector
|
21 |
+
from latentsync.utils.util import red_text
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
|
26 |
+
syncnet_detector(video_path=video_path, min_track=50)
|
27 |
+
crop_videos = os.listdir(os.path.join(detect_results_dir, "crop"))
|
28 |
+
if crop_videos == []:
|
29 |
+
raise Exception(red_text(f"Face not detected in {video_path}"))
|
30 |
+
av_offset_list = []
|
31 |
+
conf_list = []
|
32 |
+
for video in crop_videos:
|
33 |
+
av_offset, _, conf = syncnet.evaluate(
|
34 |
+
video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir
|
35 |
+
)
|
36 |
+
av_offset_list.append(av_offset)
|
37 |
+
conf_list.append(conf)
|
38 |
+
av_offset = int(fmean(av_offset_list))
|
39 |
+
conf = fmean(conf_list)
|
40 |
+
print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
|
41 |
+
return av_offset, conf
|
42 |
+
|
43 |
+
|
44 |
+
def main():
|
45 |
+
parser = argparse.ArgumentParser(description="SyncNet")
|
46 |
+
parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
|
47 |
+
parser.add_argument("--video_path", type=str, default=None, help="")
|
48 |
+
parser.add_argument("--videos_dir", type=str, default="/root/processed")
|
49 |
+
parser.add_argument("--temp_dir", type=str, default="temp", help="")
|
50 |
+
|
51 |
+
args = parser.parse_args()
|
52 |
+
|
53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
54 |
+
|
55 |
+
syncnet = SyncNetEval(device=device)
|
56 |
+
syncnet.loadParameters(args.initial_model)
|
57 |
+
|
58 |
+
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
|
59 |
+
|
60 |
+
if args.video_path is not None:
|
61 |
+
syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
|
62 |
+
else:
|
63 |
+
sync_conf_list = []
|
64 |
+
video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")])
|
65 |
+
for video_name in tqdm.tqdm(video_names):
|
66 |
+
try:
|
67 |
+
_, conf = syncnet_eval(
|
68 |
+
syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
|
69 |
+
)
|
70 |
+
sync_conf_list.append(conf)
|
71 |
+
except Exception as e:
|
72 |
+
print(e)
|
73 |
+
print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
main()
|
eval/eval_sync_conf.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
python -m eval.eval_sync_conf --video_path "RD_Radio1_000_006_out.mp4"
|
eval/eval_syncnet_acc.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
from tqdm.auto import tqdm
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from einops import rearrange
|
20 |
+
from latentsync.models.syncnet import SyncNet
|
21 |
+
from latentsync.data.syncnet_dataset import SyncNetDataset
|
22 |
+
from diffusers import AutoencoderKL
|
23 |
+
from omegaconf import OmegaConf
|
24 |
+
from accelerate.utils import set_seed
|
25 |
+
|
26 |
+
|
27 |
+
def main(config):
|
28 |
+
set_seed(config.run.seed)
|
29 |
+
|
30 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
|
32 |
+
if config.data.latent_space:
|
33 |
+
vae = AutoencoderKL.from_pretrained(
|
34 |
+
"runwayml/stable-diffusion-inpainting", subfolder="vae", revision="fp16", torch_dtype=torch.float16
|
35 |
+
)
|
36 |
+
vae.requires_grad_(False)
|
37 |
+
vae.to(device)
|
38 |
+
|
39 |
+
# Dataset and Dataloader setup
|
40 |
+
dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
|
41 |
+
|
42 |
+
test_dataloader = torch.utils.data.DataLoader(
|
43 |
+
dataset,
|
44 |
+
batch_size=config.data.batch_size,
|
45 |
+
shuffle=False,
|
46 |
+
num_workers=config.data.num_workers,
|
47 |
+
drop_last=False,
|
48 |
+
worker_init_fn=dataset.worker_init_fn,
|
49 |
+
)
|
50 |
+
|
51 |
+
# Model
|
52 |
+
syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
|
53 |
+
|
54 |
+
print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}")
|
55 |
+
checkpoint = torch.load(config.ckpt.inference_ckpt_path, map_location=device)
|
56 |
+
|
57 |
+
syncnet.load_state_dict(checkpoint["state_dict"])
|
58 |
+
syncnet.to(dtype=torch.float16)
|
59 |
+
syncnet.requires_grad_(False)
|
60 |
+
syncnet.eval()
|
61 |
+
|
62 |
+
global_step = 0
|
63 |
+
num_val_batches = config.data.num_val_samples // config.data.batch_size
|
64 |
+
progress_bar = tqdm(range(0, num_val_batches), initial=0, desc="Testing accuracy")
|
65 |
+
|
66 |
+
num_correct_preds = 0
|
67 |
+
num_total_preds = 0
|
68 |
+
|
69 |
+
while True:
|
70 |
+
for step, batch in enumerate(test_dataloader):
|
71 |
+
### >>>> Test >>>> ###
|
72 |
+
|
73 |
+
frames = batch["frames"].to(device, dtype=torch.float16)
|
74 |
+
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
|
75 |
+
y = batch["y"].to(device, dtype=torch.float16).squeeze(1)
|
76 |
+
|
77 |
+
if config.data.latent_space:
|
78 |
+
frames = rearrange(frames, "b f c h w -> (b f) c h w")
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
frames = vae.encode(frames).latent_dist.sample() * 0.18215
|
82 |
+
|
83 |
+
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
|
84 |
+
else:
|
85 |
+
frames = rearrange(frames, "b f c h w -> b (f c) h w")
|
86 |
+
|
87 |
+
if config.data.lower_half:
|
88 |
+
height = frames.shape[2]
|
89 |
+
frames = frames[:, :, height // 2 :, :]
|
90 |
+
|
91 |
+
with torch.no_grad():
|
92 |
+
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
|
93 |
+
|
94 |
+
sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
|
95 |
+
|
96 |
+
preds = (sims > 0.5).to(dtype=torch.float16)
|
97 |
+
num_correct_preds += (preds == y).sum().item()
|
98 |
+
num_total_preds += len(sims)
|
99 |
+
|
100 |
+
progress_bar.update(1)
|
101 |
+
global_step += 1
|
102 |
+
|
103 |
+
if global_step >= num_val_batches:
|
104 |
+
progress_bar.close()
|
105 |
+
print(f"Accuracy score: {num_correct_preds / num_total_preds*100:.2f}%")
|
106 |
+
return
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
parser = argparse.ArgumentParser(description="Code to test the accuracy of expert lip-sync discriminator")
|
111 |
+
|
112 |
+
parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml")
|
113 |
+
args = parser.parse_args()
|
114 |
+
|
115 |
+
# Load a configuration file
|
116 |
+
config = OmegaConf.load(args.config_path)
|
117 |
+
|
118 |
+
main(config)
|
eval/eval_syncnet_acc.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
python -m eval.eval_syncnet_acc --config_path "configs/syncnet/syncnet_16_pixel.yaml"
|
eval/fvd.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/universome/fvd-comparison/blob/master/our_fvd.py
|
2 |
+
|
3 |
+
from typing import Tuple
|
4 |
+
import scipy
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
|
10 |
+
mu_gen, sigma_gen = compute_stats(feats_fake)
|
11 |
+
mu_real, sigma_real = compute_stats(feats_real)
|
12 |
+
|
13 |
+
m = np.square(mu_gen - mu_real).sum()
|
14 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
15 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
16 |
+
|
17 |
+
return float(fid)
|
18 |
+
|
19 |
+
|
20 |
+
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
21 |
+
mu = feats.mean(axis=0) # [d]
|
22 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
23 |
+
|
24 |
+
return mu, sigma
|
25 |
+
|
26 |
+
|
27 |
+
@torch.no_grad()
|
28 |
+
def compute_our_fvd(videos_fake: np.ndarray, videos_real: np.ndarray, device: str = "cuda") -> float:
|
29 |
+
i3d_path = "checkpoints/auxiliary/i3d_torchscript.pt"
|
30 |
+
i3d_kwargs = dict(
|
31 |
+
rescale=False, resize=False, return_features=True
|
32 |
+
) # Return raw features before the softmax layer.
|
33 |
+
|
34 |
+
with open(i3d_path, "rb") as f:
|
35 |
+
i3d_model = torch.jit.load(f).eval().to(device)
|
36 |
+
|
37 |
+
videos_fake = videos_fake.permute(0, 4, 1, 2, 3).to(device)
|
38 |
+
videos_real = videos_real.permute(0, 4, 1, 2, 3).to(device)
|
39 |
+
|
40 |
+
feats_fake = i3d_model(videos_fake, **i3d_kwargs).cpu().numpy()
|
41 |
+
feats_real = i3d_model(videos_real, **i3d_kwargs).cpu().numpy()
|
42 |
+
|
43 |
+
return compute_fvd(feats_fake, feats_real)
|
44 |
+
|
45 |
+
|
46 |
+
def main():
|
47 |
+
# input shape: (b, f, h, w, c)
|
48 |
+
videos_fake = torch.rand(10, 16, 224, 224, 3)
|
49 |
+
videos_real = torch.rand(10, 16, 224, 224, 3)
|
50 |
+
|
51 |
+
our_fvd_result = compute_our_fvd(videos_fake, videos_real)
|
52 |
+
print(f"[FVD scores] Ours: {our_fvd_result}")
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
main()
|
eval/hyper_iqa.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/SSL92/hyperIQA/blob/master/models.py
|
2 |
+
|
3 |
+
import torch as torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.nn import init
|
7 |
+
import math
|
8 |
+
import torch.utils.model_zoo as model_zoo
|
9 |
+
|
10 |
+
model_urls = {
|
11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
class HyperNet(nn.Module):
|
20 |
+
"""
|
21 |
+
Hyper network for learning perceptual rules.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
lda_out_channels: local distortion aware module output size.
|
25 |
+
hyper_in_channels: input feature channels for hyper network.
|
26 |
+
target_in_size: input vector size for target network.
|
27 |
+
target_fc(i)_size: fully connection layer size of target network.
|
28 |
+
feature_size: input feature map width/height for hyper network.
|
29 |
+
|
30 |
+
Note:
|
31 |
+
For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'.
|
32 |
+
|
33 |
+
"""
|
34 |
+
def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size, target_fc3_size, target_fc4_size, feature_size):
|
35 |
+
super(HyperNet, self).__init__()
|
36 |
+
|
37 |
+
self.hyperInChn = hyper_in_channels
|
38 |
+
self.target_in_size = target_in_size
|
39 |
+
self.f1 = target_fc1_size
|
40 |
+
self.f2 = target_fc2_size
|
41 |
+
self.f3 = target_fc3_size
|
42 |
+
self.f4 = target_fc4_size
|
43 |
+
self.feature_size = feature_size
|
44 |
+
|
45 |
+
self.res = resnet50_backbone(lda_out_channels, target_in_size, pretrained=True)
|
46 |
+
|
47 |
+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
48 |
+
|
49 |
+
# Conv layers for resnet output features
|
50 |
+
self.conv1 = nn.Sequential(
|
51 |
+
nn.Conv2d(2048, 1024, 1, padding=(0, 0)),
|
52 |
+
nn.ReLU(inplace=True),
|
53 |
+
nn.Conv2d(1024, 512, 1, padding=(0, 0)),
|
54 |
+
nn.ReLU(inplace=True),
|
55 |
+
nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)),
|
56 |
+
nn.ReLU(inplace=True)
|
57 |
+
)
|
58 |
+
|
59 |
+
# Hyper network part, conv for generating target fc weights, fc for generating target fc biases
|
60 |
+
self.fc1w_conv = nn.Conv2d(self.hyperInChn, int(self.target_in_size * self.f1 / feature_size ** 2), 3, padding=(1, 1))
|
61 |
+
self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1)
|
62 |
+
|
63 |
+
self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size ** 2), 3, padding=(1, 1))
|
64 |
+
self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2)
|
65 |
+
|
66 |
+
self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size ** 2), 3, padding=(1, 1))
|
67 |
+
self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3)
|
68 |
+
|
69 |
+
self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size ** 2), 3, padding=(1, 1))
|
70 |
+
self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4)
|
71 |
+
|
72 |
+
self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4)
|
73 |
+
self.fc5b_fc = nn.Linear(self.hyperInChn, 1)
|
74 |
+
|
75 |
+
# initialize
|
76 |
+
for i, m_name in enumerate(self._modules):
|
77 |
+
if i > 2:
|
78 |
+
nn.init.kaiming_normal_(self._modules[m_name].weight.data)
|
79 |
+
|
80 |
+
def forward(self, img):
|
81 |
+
feature_size = self.feature_size
|
82 |
+
|
83 |
+
res_out = self.res(img)
|
84 |
+
|
85 |
+
# input vector for target net
|
86 |
+
target_in_vec = res_out['target_in_vec'].reshape(-1, self.target_in_size, 1, 1)
|
87 |
+
|
88 |
+
# input features for hyper net
|
89 |
+
hyper_in_feat = self.conv1(res_out['hyper_in_feat']).reshape(-1, self.hyperInChn, feature_size, feature_size)
|
90 |
+
|
91 |
+
# generating target net weights & biases
|
92 |
+
target_fc1w = self.fc1w_conv(hyper_in_feat).reshape(-1, self.f1, self.target_in_size, 1, 1)
|
93 |
+
target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f1)
|
94 |
+
|
95 |
+
target_fc2w = self.fc2w_conv(hyper_in_feat).reshape(-1, self.f2, self.f1, 1, 1)
|
96 |
+
target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f2)
|
97 |
+
|
98 |
+
target_fc3w = self.fc3w_conv(hyper_in_feat).reshape(-1, self.f3, self.f2, 1, 1)
|
99 |
+
target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f3)
|
100 |
+
|
101 |
+
target_fc4w = self.fc4w_conv(hyper_in_feat).reshape(-1, self.f4, self.f3, 1, 1)
|
102 |
+
target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f4)
|
103 |
+
|
104 |
+
target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1, self.f4, 1, 1)
|
105 |
+
target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1)
|
106 |
+
|
107 |
+
out = {}
|
108 |
+
out['target_in_vec'] = target_in_vec
|
109 |
+
out['target_fc1w'] = target_fc1w
|
110 |
+
out['target_fc1b'] = target_fc1b
|
111 |
+
out['target_fc2w'] = target_fc2w
|
112 |
+
out['target_fc2b'] = target_fc2b
|
113 |
+
out['target_fc3w'] = target_fc3w
|
114 |
+
out['target_fc3b'] = target_fc3b
|
115 |
+
out['target_fc4w'] = target_fc4w
|
116 |
+
out['target_fc4b'] = target_fc4b
|
117 |
+
out['target_fc5w'] = target_fc5w
|
118 |
+
out['target_fc5b'] = target_fc5b
|
119 |
+
|
120 |
+
return out
|
121 |
+
|
122 |
+
|
123 |
+
class TargetNet(nn.Module):
|
124 |
+
"""
|
125 |
+
Target network for quality prediction.
|
126 |
+
"""
|
127 |
+
def __init__(self, paras):
|
128 |
+
super(TargetNet, self).__init__()
|
129 |
+
self.l1 = nn.Sequential(
|
130 |
+
TargetFC(paras['target_fc1w'], paras['target_fc1b']),
|
131 |
+
nn.Sigmoid(),
|
132 |
+
)
|
133 |
+
self.l2 = nn.Sequential(
|
134 |
+
TargetFC(paras['target_fc2w'], paras['target_fc2b']),
|
135 |
+
nn.Sigmoid(),
|
136 |
+
)
|
137 |
+
|
138 |
+
self.l3 = nn.Sequential(
|
139 |
+
TargetFC(paras['target_fc3w'], paras['target_fc3b']),
|
140 |
+
nn.Sigmoid(),
|
141 |
+
)
|
142 |
+
|
143 |
+
self.l4 = nn.Sequential(
|
144 |
+
TargetFC(paras['target_fc4w'], paras['target_fc4b']),
|
145 |
+
nn.Sigmoid(),
|
146 |
+
TargetFC(paras['target_fc5w'], paras['target_fc5b']),
|
147 |
+
)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
q = self.l1(x)
|
151 |
+
# q = F.dropout(q)
|
152 |
+
q = self.l2(q)
|
153 |
+
q = self.l3(q)
|
154 |
+
q = self.l4(q).squeeze()
|
155 |
+
return q
|
156 |
+
|
157 |
+
|
158 |
+
class TargetFC(nn.Module):
|
159 |
+
"""
|
160 |
+
Fully connection operations for target net
|
161 |
+
|
162 |
+
Note:
|
163 |
+
Weights & biases are different for different images in a batch,
|
164 |
+
thus here we use group convolution for calculating images in a batch with individual weights & biases.
|
165 |
+
"""
|
166 |
+
def __init__(self, weight, bias):
|
167 |
+
super(TargetFC, self).__init__()
|
168 |
+
self.weight = weight
|
169 |
+
self.bias = bias
|
170 |
+
|
171 |
+
def forward(self, input_):
|
172 |
+
|
173 |
+
input_re = input_.reshape(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3])
|
174 |
+
weight_re = self.weight.reshape(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], self.weight.shape[3], self.weight.shape[4])
|
175 |
+
bias_re = self.bias.reshape(self.bias.shape[0] * self.bias.shape[1])
|
176 |
+
out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0])
|
177 |
+
|
178 |
+
return out.reshape(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3])
|
179 |
+
|
180 |
+
|
181 |
+
class Bottleneck(nn.Module):
|
182 |
+
expansion = 4
|
183 |
+
|
184 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
185 |
+
super(Bottleneck, self).__init__()
|
186 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
187 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
188 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
189 |
+
padding=1, bias=False)
|
190 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
191 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
192 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
193 |
+
self.relu = nn.ReLU(inplace=True)
|
194 |
+
self.downsample = downsample
|
195 |
+
self.stride = stride
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
residual = x
|
199 |
+
|
200 |
+
out = self.conv1(x)
|
201 |
+
out = self.bn1(out)
|
202 |
+
out = self.relu(out)
|
203 |
+
|
204 |
+
out = self.conv2(out)
|
205 |
+
out = self.bn2(out)
|
206 |
+
out = self.relu(out)
|
207 |
+
|
208 |
+
out = self.conv3(out)
|
209 |
+
out = self.bn3(out)
|
210 |
+
|
211 |
+
if self.downsample is not None:
|
212 |
+
residual = self.downsample(x)
|
213 |
+
|
214 |
+
out += residual
|
215 |
+
out = self.relu(out)
|
216 |
+
|
217 |
+
return out
|
218 |
+
|
219 |
+
|
220 |
+
class ResNetBackbone(nn.Module):
|
221 |
+
|
222 |
+
def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000):
|
223 |
+
super(ResNetBackbone, self).__init__()
|
224 |
+
self.inplanes = 64
|
225 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
226 |
+
self.bn1 = nn.BatchNorm2d(64)
|
227 |
+
self.relu = nn.ReLU(inplace=True)
|
228 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
229 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
230 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
231 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
232 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
233 |
+
|
234 |
+
# local distortion aware module
|
235 |
+
self.lda1_pool = nn.Sequential(
|
236 |
+
nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
|
237 |
+
nn.AvgPool2d(7, stride=7),
|
238 |
+
)
|
239 |
+
self.lda1_fc = nn.Linear(16 * 64, lda_out_channels)
|
240 |
+
|
241 |
+
self.lda2_pool = nn.Sequential(
|
242 |
+
nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
|
243 |
+
nn.AvgPool2d(7, stride=7),
|
244 |
+
)
|
245 |
+
self.lda2_fc = nn.Linear(32 * 16, lda_out_channels)
|
246 |
+
|
247 |
+
self.lda3_pool = nn.Sequential(
|
248 |
+
nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
|
249 |
+
nn.AvgPool2d(7, stride=7),
|
250 |
+
)
|
251 |
+
self.lda3_fc = nn.Linear(64 * 4, lda_out_channels)
|
252 |
+
|
253 |
+
self.lda4_pool = nn.AvgPool2d(7, stride=7)
|
254 |
+
self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3)
|
255 |
+
|
256 |
+
for m in self.modules():
|
257 |
+
if isinstance(m, nn.Conv2d):
|
258 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
259 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
260 |
+
elif isinstance(m, nn.BatchNorm2d):
|
261 |
+
m.weight.data.fill_(1)
|
262 |
+
m.bias.data.zero_()
|
263 |
+
|
264 |
+
# initialize
|
265 |
+
nn.init.kaiming_normal_(self.lda1_pool._modules['0'].weight.data)
|
266 |
+
nn.init.kaiming_normal_(self.lda2_pool._modules['0'].weight.data)
|
267 |
+
nn.init.kaiming_normal_(self.lda3_pool._modules['0'].weight.data)
|
268 |
+
nn.init.kaiming_normal_(self.lda1_fc.weight.data)
|
269 |
+
nn.init.kaiming_normal_(self.lda2_fc.weight.data)
|
270 |
+
nn.init.kaiming_normal_(self.lda3_fc.weight.data)
|
271 |
+
nn.init.kaiming_normal_(self.lda4_fc.weight.data)
|
272 |
+
|
273 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
274 |
+
downsample = None
|
275 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
276 |
+
downsample = nn.Sequential(
|
277 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
278 |
+
kernel_size=1, stride=stride, bias=False),
|
279 |
+
nn.BatchNorm2d(planes * block.expansion),
|
280 |
+
)
|
281 |
+
|
282 |
+
layers = []
|
283 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
284 |
+
self.inplanes = planes * block.expansion
|
285 |
+
for i in range(1, blocks):
|
286 |
+
layers.append(block(self.inplanes, planes))
|
287 |
+
|
288 |
+
return nn.Sequential(*layers)
|
289 |
+
|
290 |
+
def forward(self, x):
|
291 |
+
x = self.conv1(x)
|
292 |
+
x = self.bn1(x)
|
293 |
+
x = self.relu(x)
|
294 |
+
x = self.maxpool(x)
|
295 |
+
x = self.layer1(x)
|
296 |
+
|
297 |
+
# the same effect as lda operation in the paper, but save much more memory
|
298 |
+
lda_1 = self.lda1_fc(self.lda1_pool(x).reshape(x.size(0), -1))
|
299 |
+
x = self.layer2(x)
|
300 |
+
lda_2 = self.lda2_fc(self.lda2_pool(x).reshape(x.size(0), -1))
|
301 |
+
x = self.layer3(x)
|
302 |
+
lda_3 = self.lda3_fc(self.lda3_pool(x).reshape(x.size(0), -1))
|
303 |
+
x = self.layer4(x)
|
304 |
+
lda_4 = self.lda4_fc(self.lda4_pool(x).reshape(x.size(0), -1))
|
305 |
+
|
306 |
+
vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1)
|
307 |
+
|
308 |
+
out = {}
|
309 |
+
out['hyper_in_feat'] = x
|
310 |
+
out['target_in_vec'] = vec
|
311 |
+
|
312 |
+
return out
|
313 |
+
|
314 |
+
|
315 |
+
def resnet50_backbone(lda_out_channels, in_chn, pretrained=False, **kwargs):
|
316 |
+
"""Constructs a ResNet-50 model_hyper.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
|
320 |
+
"""
|
321 |
+
model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs)
|
322 |
+
if pretrained:
|
323 |
+
save_model = model_zoo.load_url(model_urls['resnet50'])
|
324 |
+
model_dict = model.state_dict()
|
325 |
+
state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
|
326 |
+
model_dict.update(state_dict)
|
327 |
+
model.load_state_dict(model_dict)
|
328 |
+
else:
|
329 |
+
model.apply(weights_init_xavier)
|
330 |
+
return model
|
331 |
+
|
332 |
+
|
333 |
+
def weights_init_xavier(m):
|
334 |
+
classname = m.__class__.__name__
|
335 |
+
# print(classname)
|
336 |
+
# if isinstance(m, nn.Conv2d):
|
337 |
+
if classname.find('Conv') != -1:
|
338 |
+
init.kaiming_normal_(m.weight.data)
|
339 |
+
elif classname.find('Linear') != -1:
|
340 |
+
init.kaiming_normal_(m.weight.data)
|
341 |
+
elif classname.find('BatchNorm2d') != -1:
|
342 |
+
init.uniform_(m.weight.data, 1.0, 0.02)
|
343 |
+
init.constant_(m.bias.data, 0.0)
|
eval/inference_videos.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import subprocess
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
|
20 |
+
def inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path):
|
21 |
+
os.makedirs(output_dir, exist_ok=True)
|
22 |
+
video_names = sorted([f for f in os.listdir(input_dir) if f.endswith(".mp4")])
|
23 |
+
for video_name in tqdm(video_names):
|
24 |
+
video_path = os.path.join(input_dir, video_name)
|
25 |
+
audio_path = os.path.join(input_dir, video_name.replace(".mp4", "_audio.wav"))
|
26 |
+
video_out_path = os.path.join(output_dir, video_name.replace(".mp4", "_out.mp4"))
|
27 |
+
inference_command = f"python inference.py --unet_config_path {unet_config_path} --video_path {video_path} --audio_path {audio_path} --video_out_path {video_out_path} --inference_ckpt_path {ckpt_path} --seed 1247"
|
28 |
+
subprocess.run(inference_command, shell=True)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/cross"
|
33 |
+
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/latentsync_cross"
|
34 |
+
unet_config_path = "configs/unet/unet_latent_16_diffusion.yaml"
|
35 |
+
ckpt_path = "output/unet/train-2024_10_08-16:23:43/checkpoints/checkpoint-1920000.pt"
|
36 |
+
|
37 |
+
inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path)
|
eval/syncnet/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .syncnet_eval import SyncNetEval
|
eval/syncnet/syncnet.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/joonson/syncnet_python/blob/master/SyncNetModel.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
def save(model, filename):
|
8 |
+
with open(filename, "wb") as f:
|
9 |
+
torch.save(model, f)
|
10 |
+
print("%s saved." % filename)
|
11 |
+
|
12 |
+
|
13 |
+
def load(filename):
|
14 |
+
net = torch.load(filename)
|
15 |
+
return net
|
16 |
+
|
17 |
+
|
18 |
+
class S(nn.Module):
|
19 |
+
def __init__(self, num_layers_in_fc_layers=1024):
|
20 |
+
super(S, self).__init__()
|
21 |
+
|
22 |
+
self.__nFeatures__ = 24
|
23 |
+
self.__nChs__ = 32
|
24 |
+
self.__midChs__ = 32
|
25 |
+
|
26 |
+
self.netcnnaud = nn.Sequential(
|
27 |
+
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
|
28 |
+
nn.BatchNorm2d(64),
|
29 |
+
nn.ReLU(inplace=True),
|
30 |
+
nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1)),
|
31 |
+
nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
|
32 |
+
nn.BatchNorm2d(192),
|
33 |
+
nn.ReLU(inplace=True),
|
34 |
+
nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 2)),
|
35 |
+
nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
|
36 |
+
nn.BatchNorm2d(384),
|
37 |
+
nn.ReLU(inplace=True),
|
38 |
+
nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
|
39 |
+
nn.BatchNorm2d(256),
|
40 |
+
nn.ReLU(inplace=True),
|
41 |
+
nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
|
42 |
+
nn.BatchNorm2d(256),
|
43 |
+
nn.ReLU(inplace=True),
|
44 |
+
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
|
45 |
+
nn.Conv2d(256, 512, kernel_size=(5, 4), padding=(0, 0)),
|
46 |
+
nn.BatchNorm2d(512),
|
47 |
+
nn.ReLU(),
|
48 |
+
)
|
49 |
+
|
50 |
+
self.netfcaud = nn.Sequential(
|
51 |
+
nn.Linear(512, 512),
|
52 |
+
nn.BatchNorm1d(512),
|
53 |
+
nn.ReLU(),
|
54 |
+
nn.Linear(512, num_layers_in_fc_layers),
|
55 |
+
)
|
56 |
+
|
57 |
+
self.netfclip = nn.Sequential(
|
58 |
+
nn.Linear(512, 512),
|
59 |
+
nn.BatchNorm1d(512),
|
60 |
+
nn.ReLU(),
|
61 |
+
nn.Linear(512, num_layers_in_fc_layers),
|
62 |
+
)
|
63 |
+
|
64 |
+
self.netcnnlip = nn.Sequential(
|
65 |
+
nn.Conv3d(3, 96, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=0),
|
66 |
+
nn.BatchNorm3d(96),
|
67 |
+
nn.ReLU(inplace=True),
|
68 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
|
69 |
+
nn.Conv3d(96, 256, kernel_size=(1, 5, 5), stride=(1, 2, 2), padding=(0, 1, 1)),
|
70 |
+
nn.BatchNorm3d(256),
|
71 |
+
nn.ReLU(inplace=True),
|
72 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
|
73 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
74 |
+
nn.BatchNorm3d(256),
|
75 |
+
nn.ReLU(inplace=True),
|
76 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
77 |
+
nn.BatchNorm3d(256),
|
78 |
+
nn.ReLU(inplace=True),
|
79 |
+
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
|
80 |
+
nn.BatchNorm3d(256),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
|
83 |
+
nn.Conv3d(256, 512, kernel_size=(1, 6, 6), padding=0),
|
84 |
+
nn.BatchNorm3d(512),
|
85 |
+
nn.ReLU(inplace=True),
|
86 |
+
)
|
87 |
+
|
88 |
+
def forward_aud(self, x):
|
89 |
+
|
90 |
+
mid = self.netcnnaud(x)
|
91 |
+
# N x ch x 24 x M
|
92 |
+
mid = mid.view((mid.size()[0], -1))
|
93 |
+
# N x (ch x 24)
|
94 |
+
out = self.netfcaud(mid)
|
95 |
+
|
96 |
+
return out
|
97 |
+
|
98 |
+
def forward_lip(self, x):
|
99 |
+
|
100 |
+
mid = self.netcnnlip(x)
|
101 |
+
mid = mid.view((mid.size()[0], -1))
|
102 |
+
# N x (ch x 24)
|
103 |
+
out = self.netfclip(mid)
|
104 |
+
|
105 |
+
return out
|
106 |
+
|
107 |
+
def forward_lipfeat(self, x):
|
108 |
+
|
109 |
+
mid = self.netcnnlip(x)
|
110 |
+
out = mid.view((mid.size()[0], -1))
|
111 |
+
# N x (ch x 24)
|
112 |
+
|
113 |
+
return out
|
eval/syncnet/syncnet_eval.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/joonson/syncnet_python/blob/master/SyncNetInstance.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy
|
5 |
+
import time, pdb, argparse, subprocess, os, math, glob
|
6 |
+
import cv2
|
7 |
+
import python_speech_features
|
8 |
+
|
9 |
+
from scipy import signal
|
10 |
+
from scipy.io import wavfile
|
11 |
+
from .syncnet import S
|
12 |
+
from shutil import rmtree
|
13 |
+
|
14 |
+
|
15 |
+
# ==================== Get OFFSET ====================
|
16 |
+
|
17 |
+
# Video 25 FPS, Audio 16000HZ
|
18 |
+
|
19 |
+
|
20 |
+
def calc_pdist(feat1, feat2, vshift=10):
|
21 |
+
win_size = vshift * 2 + 1
|
22 |
+
|
23 |
+
feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift))
|
24 |
+
|
25 |
+
dists = []
|
26 |
+
|
27 |
+
for i in range(0, len(feat1)):
|
28 |
+
|
29 |
+
dists.append(
|
30 |
+
torch.nn.functional.pairwise_distance(feat1[[i], :].repeat(win_size, 1), feat2p[i : i + win_size, :])
|
31 |
+
)
|
32 |
+
|
33 |
+
return dists
|
34 |
+
|
35 |
+
|
36 |
+
# ==================== MAIN DEF ====================
|
37 |
+
|
38 |
+
|
39 |
+
class SyncNetEval(torch.nn.Module):
|
40 |
+
def __init__(self, dropout=0, num_layers_in_fc_layers=1024, device="cpu"):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.__S__ = S(num_layers_in_fc_layers=num_layers_in_fc_layers).to(device)
|
44 |
+
self.device = device
|
45 |
+
|
46 |
+
def evaluate(self, video_path, temp_dir="temp", batch_size=20, vshift=15):
|
47 |
+
|
48 |
+
self.__S__.eval()
|
49 |
+
|
50 |
+
# ========== ==========
|
51 |
+
# Convert files
|
52 |
+
# ========== ==========
|
53 |
+
|
54 |
+
if os.path.exists(temp_dir):
|
55 |
+
rmtree(temp_dir)
|
56 |
+
|
57 |
+
os.makedirs(temp_dir)
|
58 |
+
|
59 |
+
# temp_video_path = os.path.join(temp_dir, "temp.mp4")
|
60 |
+
# command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -vf scale='224:224' {temp_video_path}"
|
61 |
+
# subprocess.call(command, shell=True)
|
62 |
+
|
63 |
+
command = (
|
64 |
+
f"ffmpeg -loglevel error -nostdin -y -i {video_path} -f image2 {os.path.join(temp_dir, '%06d.jpg')}"
|
65 |
+
)
|
66 |
+
subprocess.call(command, shell=True, stdout=None)
|
67 |
+
|
68 |
+
command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(temp_dir, 'audio.wav')}"
|
69 |
+
subprocess.call(command, shell=True, stdout=None)
|
70 |
+
|
71 |
+
# ========== ==========
|
72 |
+
# Load video
|
73 |
+
# ========== ==========
|
74 |
+
|
75 |
+
images = []
|
76 |
+
|
77 |
+
flist = glob.glob(os.path.join(temp_dir, "*.jpg"))
|
78 |
+
flist.sort()
|
79 |
+
|
80 |
+
for fname in flist:
|
81 |
+
img_input = cv2.imread(fname)
|
82 |
+
img_input = cv2.resize(img_input, (224, 224)) # HARD CODED, CHANGE BEFORE RELEASE
|
83 |
+
images.append(img_input)
|
84 |
+
|
85 |
+
im = numpy.stack(images, axis=3)
|
86 |
+
im = numpy.expand_dims(im, axis=0)
|
87 |
+
im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
88 |
+
|
89 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
90 |
+
|
91 |
+
# ========== ==========
|
92 |
+
# Load audio
|
93 |
+
# ========== ==========
|
94 |
+
|
95 |
+
sample_rate, audio = wavfile.read(os.path.join(temp_dir, "audio.wav"))
|
96 |
+
mfcc = zip(*python_speech_features.mfcc(audio, sample_rate))
|
97 |
+
mfcc = numpy.stack([numpy.array(i) for i in mfcc])
|
98 |
+
|
99 |
+
cc = numpy.expand_dims(numpy.expand_dims(mfcc, axis=0), axis=0)
|
100 |
+
cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
|
101 |
+
|
102 |
+
# ========== ==========
|
103 |
+
# Check audio and video input length
|
104 |
+
# ========== ==========
|
105 |
+
|
106 |
+
# if (float(len(audio)) / 16000) != (float(len(images)) / 25):
|
107 |
+
# print(
|
108 |
+
# "WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."
|
109 |
+
# % (float(len(audio)) / 16000, float(len(images)) / 25)
|
110 |
+
# )
|
111 |
+
|
112 |
+
min_length = min(len(images), math.floor(len(audio) / 640))
|
113 |
+
|
114 |
+
# ========== ==========
|
115 |
+
# Generate video and audio feats
|
116 |
+
# ========== ==========
|
117 |
+
|
118 |
+
lastframe = min_length - 5
|
119 |
+
im_feat = []
|
120 |
+
cc_feat = []
|
121 |
+
|
122 |
+
tS = time.time()
|
123 |
+
for i in range(0, lastframe, batch_size):
|
124 |
+
|
125 |
+
im_batch = [imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + batch_size))]
|
126 |
+
im_in = torch.cat(im_batch, 0)
|
127 |
+
im_out = self.__S__.forward_lip(im_in.to(self.device))
|
128 |
+
im_feat.append(im_out.data.cpu())
|
129 |
+
|
130 |
+
cc_batch = [
|
131 |
+
cct[:, :, :, vframe * 4 : vframe * 4 + 20] for vframe in range(i, min(lastframe, i + batch_size))
|
132 |
+
]
|
133 |
+
cc_in = torch.cat(cc_batch, 0)
|
134 |
+
cc_out = self.__S__.forward_aud(cc_in.to(self.device))
|
135 |
+
cc_feat.append(cc_out.data.cpu())
|
136 |
+
|
137 |
+
im_feat = torch.cat(im_feat, 0)
|
138 |
+
cc_feat = torch.cat(cc_feat, 0)
|
139 |
+
|
140 |
+
# ========== ==========
|
141 |
+
# Compute offset
|
142 |
+
# ========== ==========
|
143 |
+
|
144 |
+
dists = calc_pdist(im_feat, cc_feat, vshift=vshift)
|
145 |
+
mean_dists = torch.mean(torch.stack(dists, 1), 1)
|
146 |
+
|
147 |
+
min_dist, minidx = torch.min(mean_dists, 0)
|
148 |
+
|
149 |
+
av_offset = vshift - minidx
|
150 |
+
conf = torch.median(mean_dists) - min_dist
|
151 |
+
|
152 |
+
fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
|
153 |
+
# fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
|
154 |
+
fconf = torch.median(mean_dists).numpy() - fdist
|
155 |
+
framewise_conf = signal.medfilt(fconf, kernel_size=9)
|
156 |
+
|
157 |
+
# numpy.set_printoptions(formatter={"float": "{: 0.3f}".format})
|
158 |
+
rmtree(temp_dir)
|
159 |
+
return av_offset.item(), min_dist.item(), conf.item()
|
160 |
+
|
161 |
+
def extract_feature(self, opt, videofile):
|
162 |
+
|
163 |
+
self.__S__.eval()
|
164 |
+
|
165 |
+
# ========== ==========
|
166 |
+
# Load video
|
167 |
+
# ========== ==========
|
168 |
+
cap = cv2.VideoCapture(videofile)
|
169 |
+
|
170 |
+
frame_num = 1
|
171 |
+
images = []
|
172 |
+
while frame_num:
|
173 |
+
frame_num += 1
|
174 |
+
ret, image = cap.read()
|
175 |
+
if ret == 0:
|
176 |
+
break
|
177 |
+
|
178 |
+
images.append(image)
|
179 |
+
|
180 |
+
im = numpy.stack(images, axis=3)
|
181 |
+
im = numpy.expand_dims(im, axis=0)
|
182 |
+
im = numpy.transpose(im, (0, 3, 4, 1, 2))
|
183 |
+
|
184 |
+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
|
185 |
+
|
186 |
+
# ========== ==========
|
187 |
+
# Generate video feats
|
188 |
+
# ========== ==========
|
189 |
+
|
190 |
+
lastframe = len(images) - 4
|
191 |
+
im_feat = []
|
192 |
+
|
193 |
+
tS = time.time()
|
194 |
+
for i in range(0, lastframe, opt.batch_size):
|
195 |
+
|
196 |
+
im_batch = [
|
197 |
+
imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + opt.batch_size))
|
198 |
+
]
|
199 |
+
im_in = torch.cat(im_batch, 0)
|
200 |
+
im_out = self.__S__.forward_lipfeat(im_in.to(self.device))
|
201 |
+
im_feat.append(im_out.data.cpu())
|
202 |
+
|
203 |
+
im_feat = torch.cat(im_feat, 0)
|
204 |
+
|
205 |
+
# ========== ==========
|
206 |
+
# Compute offset
|
207 |
+
# ========== ==========
|
208 |
+
|
209 |
+
print("Compute time %.3f sec." % (time.time() - tS))
|
210 |
+
|
211 |
+
return im_feat
|
212 |
+
|
213 |
+
def loadParameters(self, path):
|
214 |
+
loaded_state = torch.load(path, map_location=lambda storage, loc: storage)
|
215 |
+
|
216 |
+
self_state = self.__S__.state_dict()
|
217 |
+
|
218 |
+
for name, param in loaded_state.items():
|
219 |
+
|
220 |
+
self_state[name].copy_(param)
|
eval/syncnet_detect.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/joonson/syncnet_python/blob/master/run_pipeline.py
|
2 |
+
|
3 |
+
import os, pdb, subprocess, glob, cv2
|
4 |
+
import numpy as np
|
5 |
+
from shutil import rmtree
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from scenedetect.video_manager import VideoManager
|
9 |
+
from scenedetect.scene_manager import SceneManager
|
10 |
+
from scenedetect.stats_manager import StatsManager
|
11 |
+
from scenedetect.detectors import ContentDetector
|
12 |
+
|
13 |
+
from scipy.interpolate import interp1d
|
14 |
+
from scipy.io import wavfile
|
15 |
+
from scipy import signal
|
16 |
+
|
17 |
+
from eval.detectors import S3FD
|
18 |
+
|
19 |
+
|
20 |
+
class SyncNetDetector:
|
21 |
+
def __init__(self, device, detect_results_dir="detect_results"):
|
22 |
+
self.s3f_detector = S3FD(device=device)
|
23 |
+
self.detect_results_dir = detect_results_dir
|
24 |
+
|
25 |
+
def __call__(self, video_path: str, min_track=50, scale=False):
|
26 |
+
crop_dir = os.path.join(self.detect_results_dir, "crop")
|
27 |
+
video_dir = os.path.join(self.detect_results_dir, "video")
|
28 |
+
frames_dir = os.path.join(self.detect_results_dir, "frames")
|
29 |
+
temp_dir = os.path.join(self.detect_results_dir, "temp")
|
30 |
+
|
31 |
+
# ========== DELETE EXISTING DIRECTORIES ==========
|
32 |
+
if os.path.exists(crop_dir):
|
33 |
+
rmtree(crop_dir)
|
34 |
+
|
35 |
+
if os.path.exists(video_dir):
|
36 |
+
rmtree(video_dir)
|
37 |
+
|
38 |
+
if os.path.exists(frames_dir):
|
39 |
+
rmtree(frames_dir)
|
40 |
+
|
41 |
+
if os.path.exists(temp_dir):
|
42 |
+
rmtree(temp_dir)
|
43 |
+
|
44 |
+
# ========== MAKE NEW DIRECTORIES ==========
|
45 |
+
|
46 |
+
os.makedirs(crop_dir)
|
47 |
+
os.makedirs(video_dir)
|
48 |
+
os.makedirs(frames_dir)
|
49 |
+
os.makedirs(temp_dir)
|
50 |
+
|
51 |
+
# ========== CONVERT VIDEO AND EXTRACT FRAMES ==========
|
52 |
+
|
53 |
+
if scale:
|
54 |
+
scaled_video_path = os.path.join(video_dir, "scaled.mp4")
|
55 |
+
command = f"ffmpeg -loglevel error -y -nostdin -i {video_path} -vf scale='224:224' {scaled_video_path}"
|
56 |
+
subprocess.run(command, shell=True)
|
57 |
+
video_path = scaled_video_path
|
58 |
+
|
59 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {video_path} -qscale:v 2 -async 1 -r 25 {os.path.join(video_dir, 'video.mp4')}"
|
60 |
+
subprocess.run(command, shell=True, stdout=None)
|
61 |
+
|
62 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -qscale:v 2 -f image2 {os.path.join(frames_dir, '%06d.jpg')}"
|
63 |
+
subprocess.run(command, shell=True, stdout=None)
|
64 |
+
|
65 |
+
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(video_dir, 'audio.wav')}"
|
66 |
+
subprocess.run(command, shell=True, stdout=None)
|
67 |
+
|
68 |
+
faces = self.detect_face(frames_dir)
|
69 |
+
|
70 |
+
scene = self.scene_detect(video_dir)
|
71 |
+
|
72 |
+
# Face tracking
|
73 |
+
alltracks = []
|
74 |
+
|
75 |
+
for shot in scene:
|
76 |
+
if shot[1].frame_num - shot[0].frame_num >= min_track:
|
77 |
+
alltracks.extend(self.track_face(faces[shot[0].frame_num : shot[1].frame_num], min_track=min_track))
|
78 |
+
|
79 |
+
# Face crop
|
80 |
+
for ii, track in enumerate(alltracks):
|
81 |
+
self.crop_video(track, os.path.join(crop_dir, "%05d" % ii), frames_dir, 25, temp_dir, video_dir)
|
82 |
+
|
83 |
+
rmtree(temp_dir)
|
84 |
+
|
85 |
+
def scene_detect(self, video_dir):
|
86 |
+
video_manager = VideoManager([os.path.join(video_dir, "video.mp4")])
|
87 |
+
stats_manager = StatsManager()
|
88 |
+
scene_manager = SceneManager(stats_manager)
|
89 |
+
# Add ContentDetector algorithm (constructor takes detector options like threshold).
|
90 |
+
scene_manager.add_detector(ContentDetector())
|
91 |
+
base_timecode = video_manager.get_base_timecode()
|
92 |
+
|
93 |
+
video_manager.set_downscale_factor()
|
94 |
+
|
95 |
+
video_manager.start()
|
96 |
+
|
97 |
+
scene_manager.detect_scenes(frame_source=video_manager)
|
98 |
+
|
99 |
+
scene_list = scene_manager.get_scene_list(base_timecode)
|
100 |
+
|
101 |
+
if scene_list == []:
|
102 |
+
scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())]
|
103 |
+
|
104 |
+
return scene_list
|
105 |
+
|
106 |
+
def track_face(self, scenefaces, num_failed_det=25, min_track=50, min_face_size=100):
|
107 |
+
|
108 |
+
iouThres = 0.5 # Minimum IOU between consecutive face detections
|
109 |
+
tracks = []
|
110 |
+
|
111 |
+
while True:
|
112 |
+
track = []
|
113 |
+
for framefaces in scenefaces:
|
114 |
+
for face in framefaces:
|
115 |
+
if track == []:
|
116 |
+
track.append(face)
|
117 |
+
framefaces.remove(face)
|
118 |
+
elif face["frame"] - track[-1]["frame"] <= num_failed_det:
|
119 |
+
iou = bounding_box_iou(face["bbox"], track[-1]["bbox"])
|
120 |
+
if iou > iouThres:
|
121 |
+
track.append(face)
|
122 |
+
framefaces.remove(face)
|
123 |
+
continue
|
124 |
+
else:
|
125 |
+
break
|
126 |
+
|
127 |
+
if track == []:
|
128 |
+
break
|
129 |
+
elif len(track) > min_track:
|
130 |
+
|
131 |
+
framenum = np.array([f["frame"] for f in track])
|
132 |
+
bboxes = np.array([np.array(f["bbox"]) for f in track])
|
133 |
+
|
134 |
+
frame_i = np.arange(framenum[0], framenum[-1] + 1)
|
135 |
+
|
136 |
+
bboxes_i = []
|
137 |
+
for ij in range(0, 4):
|
138 |
+
interpfn = interp1d(framenum, bboxes[:, ij])
|
139 |
+
bboxes_i.append(interpfn(frame_i))
|
140 |
+
bboxes_i = np.stack(bboxes_i, axis=1)
|
141 |
+
|
142 |
+
if (
|
143 |
+
max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1]))
|
144 |
+
> min_face_size
|
145 |
+
):
|
146 |
+
tracks.append({"frame": frame_i, "bbox": bboxes_i})
|
147 |
+
|
148 |
+
return tracks
|
149 |
+
|
150 |
+
def detect_face(self, frames_dir, facedet_scale=0.25):
|
151 |
+
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
|
152 |
+
flist.sort()
|
153 |
+
|
154 |
+
dets = []
|
155 |
+
|
156 |
+
for fidx, fname in enumerate(flist):
|
157 |
+
image = cv2.imread(fname)
|
158 |
+
|
159 |
+
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
160 |
+
bboxes = self.s3f_detector.detect_faces(image_np, conf_th=0.9, scales=[facedet_scale])
|
161 |
+
|
162 |
+
dets.append([])
|
163 |
+
for bbox in bboxes:
|
164 |
+
dets[-1].append({"frame": fidx, "bbox": (bbox[:-1]).tolist(), "conf": bbox[-1]})
|
165 |
+
|
166 |
+
return dets
|
167 |
+
|
168 |
+
def crop_video(self, track, cropfile, frames_dir, frame_rate, temp_dir, video_dir, crop_scale=0.4):
|
169 |
+
|
170 |
+
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
|
171 |
+
flist.sort()
|
172 |
+
|
173 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
174 |
+
vOut = cv2.VideoWriter(cropfile + "t.mp4", fourcc, frame_rate, (224, 224))
|
175 |
+
|
176 |
+
dets = {"x": [], "y": [], "s": []}
|
177 |
+
|
178 |
+
for det in track["bbox"]:
|
179 |
+
|
180 |
+
dets["s"].append(max((det[3] - det[1]), (det[2] - det[0])) / 2)
|
181 |
+
dets["y"].append((det[1] + det[3]) / 2) # crop center x
|
182 |
+
dets["x"].append((det[0] + det[2]) / 2) # crop center y
|
183 |
+
|
184 |
+
# Smooth detections
|
185 |
+
dets["s"] = signal.medfilt(dets["s"], kernel_size=13)
|
186 |
+
dets["x"] = signal.medfilt(dets["x"], kernel_size=13)
|
187 |
+
dets["y"] = signal.medfilt(dets["y"], kernel_size=13)
|
188 |
+
|
189 |
+
for fidx, frame in enumerate(track["frame"]):
|
190 |
+
|
191 |
+
cs = crop_scale
|
192 |
+
|
193 |
+
bs = dets["s"][fidx] # Detection box size
|
194 |
+
bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount
|
195 |
+
|
196 |
+
image = cv2.imread(flist[frame])
|
197 |
+
|
198 |
+
frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), "constant", constant_values=(110, 110))
|
199 |
+
my = dets["y"][fidx] + bsi # BBox center Y
|
200 |
+
mx = dets["x"][fidx] + bsi # BBox center X
|
201 |
+
|
202 |
+
face = frame[int(my - bs) : int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)) : int(mx + bs * (1 + cs))]
|
203 |
+
|
204 |
+
vOut.write(cv2.resize(face, (224, 224)))
|
205 |
+
|
206 |
+
audiotmp = os.path.join(temp_dir, "audio.wav")
|
207 |
+
audiostart = (track["frame"][0]) / frame_rate
|
208 |
+
audioend = (track["frame"][-1] + 1) / frame_rate
|
209 |
+
|
210 |
+
vOut.release()
|
211 |
+
|
212 |
+
# ========== CROP AUDIO FILE ==========
|
213 |
+
|
214 |
+
command = "ffmpeg -y -nostdin -loglevel error -i %s -ss %.3f -to %.3f %s" % (
|
215 |
+
os.path.join(video_dir, "audio.wav"),
|
216 |
+
audiostart,
|
217 |
+
audioend,
|
218 |
+
audiotmp,
|
219 |
+
)
|
220 |
+
output = subprocess.run(command, shell=True, stdout=None)
|
221 |
+
|
222 |
+
sample_rate, audio = wavfile.read(audiotmp)
|
223 |
+
|
224 |
+
# ========== COMBINE AUDIO AND VIDEO FILES ==========
|
225 |
+
|
226 |
+
command = "ffmpeg -y -nostdin -loglevel error -i %st.mp4 -i %s -c:v copy -c:a aac %s.mp4" % (
|
227 |
+
cropfile,
|
228 |
+
audiotmp,
|
229 |
+
cropfile,
|
230 |
+
)
|
231 |
+
output = subprocess.run(command, shell=True, stdout=None)
|
232 |
+
|
233 |
+
os.remove(cropfile + "t.mp4")
|
234 |
+
|
235 |
+
return {"track": track, "proc_track": dets}
|
236 |
+
|
237 |
+
|
238 |
+
def bounding_box_iou(boxA, boxB):
|
239 |
+
xA = max(boxA[0], boxB[0])
|
240 |
+
yA = max(boxA[1], boxB[1])
|
241 |
+
xB = min(boxA[2], boxB[2])
|
242 |
+
yB = min(boxA[3], boxB[3])
|
243 |
+
|
244 |
+
interArea = max(0, xB - xA) * max(0, yB - yA)
|
245 |
+
|
246 |
+
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
247 |
+
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
248 |
+
|
249 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
250 |
+
|
251 |
+
return iou
|
inference.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
python -m scripts.inference \
|
4 |
+
--unet_config_path "configs/unet/second_stage.yaml" \
|
5 |
+
--inference_ckpt_path "checkpoints/latentsync_unet.pt" \
|
6 |
+
--guidance_scale 1.0 \
|
7 |
+
--video_path "assets/demo1_video.mp4" \
|
8 |
+
--audio_path "assets/demo1_audio.wav" \
|
9 |
+
--video_out_path "video_out.mp4"
|
latentsync/data/syncnet_dataset.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import numpy as np
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
import torch
|
19 |
+
import random
|
20 |
+
from ..utils.util import gather_video_paths_recursively
|
21 |
+
from ..utils.image_processor import ImageProcessor
|
22 |
+
from ..utils.audio import melspectrogram
|
23 |
+
import math
|
24 |
+
|
25 |
+
from decord import AudioReader, VideoReader, cpu
|
26 |
+
|
27 |
+
|
28 |
+
class SyncNetDataset(Dataset):
|
29 |
+
def __init__(self, data_dir: str, fileslist: str, config):
|
30 |
+
if fileslist != "":
|
31 |
+
with open(fileslist) as file:
|
32 |
+
self.video_paths = [line.rstrip() for line in file]
|
33 |
+
elif data_dir != "":
|
34 |
+
self.video_paths = gather_video_paths_recursively(data_dir)
|
35 |
+
else:
|
36 |
+
raise ValueError("data_dir and fileslist cannot be both empty")
|
37 |
+
|
38 |
+
self.resolution = config.data.resolution
|
39 |
+
self.num_frames = config.data.num_frames
|
40 |
+
|
41 |
+
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
|
42 |
+
|
43 |
+
self.audio_sample_rate = config.data.audio_sample_rate
|
44 |
+
self.video_fps = config.data.video_fps
|
45 |
+
self.audio_samples_length = int(
|
46 |
+
config.data.audio_sample_rate // config.data.video_fps * config.data.num_frames
|
47 |
+
)
|
48 |
+
self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
|
49 |
+
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
50 |
+
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return len(self.video_paths)
|
54 |
+
|
55 |
+
def read_audio(self, video_path: str):
|
56 |
+
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
57 |
+
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
58 |
+
return torch.from_numpy(original_mel)
|
59 |
+
|
60 |
+
def crop_audio_window(self, original_mel, start_index):
|
61 |
+
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
62 |
+
end_idx = start_idx + self.mel_window_length
|
63 |
+
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
64 |
+
|
65 |
+
def get_frames(self, video_reader: VideoReader):
|
66 |
+
total_num_frames = len(video_reader)
|
67 |
+
|
68 |
+
start_idx = random.randint(0, total_num_frames - self.num_frames)
|
69 |
+
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
70 |
+
|
71 |
+
while True:
|
72 |
+
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
73 |
+
# wrong_start_idx = random.randint(
|
74 |
+
# max(0, start_idx - 25), min(total_num_frames - self.num_frames, start_idx + 25)
|
75 |
+
# )
|
76 |
+
if wrong_start_idx == start_idx:
|
77 |
+
continue
|
78 |
+
# if wrong_start_idx >= start_idx - self.num_frames and wrong_start_idx <= start_idx + self.num_frames:
|
79 |
+
# continue
|
80 |
+
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
|
81 |
+
break
|
82 |
+
|
83 |
+
frames = video_reader.get_batch(frames_index).asnumpy()
|
84 |
+
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
|
85 |
+
|
86 |
+
return frames, wrong_frames, start_idx
|
87 |
+
|
88 |
+
def worker_init_fn(self, worker_id):
|
89 |
+
# Initialize the face mesh object in each worker process,
|
90 |
+
# because the face mesh object cannot be called in subprocesses
|
91 |
+
self.worker_id = worker_id
|
92 |
+
# setattr(self, f"image_processor_{worker_id}", ImageProcessor(self.resolution, self.mask))
|
93 |
+
|
94 |
+
def __getitem__(self, idx):
|
95 |
+
# image_processor = getattr(self, f"image_processor_{self.worker_id}")
|
96 |
+
while True:
|
97 |
+
try:
|
98 |
+
idx = random.randint(0, len(self) - 1)
|
99 |
+
|
100 |
+
# Get video file path
|
101 |
+
video_path = self.video_paths[idx]
|
102 |
+
|
103 |
+
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
104 |
+
|
105 |
+
if len(vr) < 2 * self.num_frames:
|
106 |
+
continue
|
107 |
+
|
108 |
+
frames, wrong_frames, start_idx = self.get_frames(vr)
|
109 |
+
|
110 |
+
mel_cache_path = os.path.join(
|
111 |
+
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
112 |
+
)
|
113 |
+
|
114 |
+
if os.path.isfile(mel_cache_path):
|
115 |
+
try:
|
116 |
+
original_mel = torch.load(mel_cache_path)
|
117 |
+
except Exception as e:
|
118 |
+
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
119 |
+
os.remove(mel_cache_path)
|
120 |
+
original_mel = self.read_audio(video_path)
|
121 |
+
torch.save(original_mel, mel_cache_path)
|
122 |
+
else:
|
123 |
+
original_mel = self.read_audio(video_path)
|
124 |
+
torch.save(original_mel, mel_cache_path)
|
125 |
+
|
126 |
+
mel = self.crop_audio_window(original_mel, start_idx)
|
127 |
+
|
128 |
+
if mel.shape[-1] != self.mel_window_length:
|
129 |
+
continue
|
130 |
+
|
131 |
+
if random.choice([True, False]):
|
132 |
+
y = torch.ones(1).float()
|
133 |
+
chosen_frames = frames
|
134 |
+
else:
|
135 |
+
y = torch.zeros(1).float()
|
136 |
+
chosen_frames = wrong_frames
|
137 |
+
|
138 |
+
chosen_frames = self.image_processor.process_images(chosen_frames)
|
139 |
+
# chosen_frames, _, _ = image_processor.prepare_masks_and_masked_images(
|
140 |
+
# chosen_frames, affine_transform=True
|
141 |
+
# )
|
142 |
+
|
143 |
+
vr.seek(0) # avoid memory leak
|
144 |
+
break
|
145 |
+
|
146 |
+
except Exception as e: # Handle the exception of face not detcted
|
147 |
+
print(f"{type(e).__name__} - {e} - {video_path}")
|
148 |
+
if "vr" in locals():
|
149 |
+
vr.seek(0) # avoid memory leak
|
150 |
+
|
151 |
+
sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
|
152 |
+
|
153 |
+
return sample
|
latentsync/data/unet_dataset.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import numpy as np
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
import torch
|
19 |
+
import random
|
20 |
+
import cv2
|
21 |
+
from ..utils.image_processor import ImageProcessor, load_fixed_mask
|
22 |
+
from ..utils.audio import melspectrogram
|
23 |
+
from decord import AudioReader, VideoReader, cpu
|
24 |
+
|
25 |
+
|
26 |
+
class UNetDataset(Dataset):
|
27 |
+
def __init__(self, train_data_dir: str, config):
|
28 |
+
if config.data.train_fileslist != "":
|
29 |
+
with open(config.data.train_fileslist) as file:
|
30 |
+
self.video_paths = [line.rstrip() for line in file]
|
31 |
+
elif train_data_dir != "":
|
32 |
+
self.video_paths = []
|
33 |
+
for file in os.listdir(train_data_dir):
|
34 |
+
if file.endswith(".mp4"):
|
35 |
+
self.video_paths.append(os.path.join(train_data_dir, file))
|
36 |
+
else:
|
37 |
+
raise ValueError("data_dir and fileslist cannot be both empty")
|
38 |
+
|
39 |
+
self.resolution = config.data.resolution
|
40 |
+
self.num_frames = config.data.num_frames
|
41 |
+
|
42 |
+
if self.num_frames == 16:
|
43 |
+
self.mel_window_length = 52
|
44 |
+
elif self.num_frames == 5:
|
45 |
+
self.mel_window_length = 16
|
46 |
+
else:
|
47 |
+
raise NotImplementedError("Only support 16 and 5 frames now")
|
48 |
+
|
49 |
+
self.audio_sample_rate = config.data.audio_sample_rate
|
50 |
+
self.video_fps = config.data.video_fps
|
51 |
+
self.mask = config.data.mask
|
52 |
+
self.mask_image = load_fixed_mask(self.resolution)
|
53 |
+
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
|
54 |
+
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
55 |
+
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
|
56 |
+
|
57 |
+
def __len__(self):
|
58 |
+
return len(self.video_paths)
|
59 |
+
|
60 |
+
def read_audio(self, video_path: str):
|
61 |
+
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
62 |
+
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
63 |
+
return torch.from_numpy(original_mel)
|
64 |
+
|
65 |
+
def crop_audio_window(self, original_mel, start_index):
|
66 |
+
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
67 |
+
end_idx = start_idx + self.mel_window_length
|
68 |
+
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
69 |
+
|
70 |
+
def get_frames(self, video_reader: VideoReader):
|
71 |
+
total_num_frames = len(video_reader)
|
72 |
+
|
73 |
+
start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
|
74 |
+
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
75 |
+
|
76 |
+
while True:
|
77 |
+
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
78 |
+
if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
|
79 |
+
continue
|
80 |
+
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
|
81 |
+
break
|
82 |
+
|
83 |
+
frames = video_reader.get_batch(frames_index).asnumpy()
|
84 |
+
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
|
85 |
+
|
86 |
+
return frames, wrong_frames, start_idx
|
87 |
+
|
88 |
+
def worker_init_fn(self, worker_id):
|
89 |
+
# Initialize the face mesh object in each worker process,
|
90 |
+
# because the face mesh object cannot be called in subprocesses
|
91 |
+
self.worker_id = worker_id
|
92 |
+
setattr(
|
93 |
+
self,
|
94 |
+
f"image_processor_{worker_id}",
|
95 |
+
ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
|
96 |
+
)
|
97 |
+
|
98 |
+
def __getitem__(self, idx):
|
99 |
+
image_processor = getattr(self, f"image_processor_{self.worker_id}")
|
100 |
+
while True:
|
101 |
+
try:
|
102 |
+
idx = random.randint(0, len(self) - 1)
|
103 |
+
|
104 |
+
# Get video file path
|
105 |
+
video_path = self.video_paths[idx]
|
106 |
+
|
107 |
+
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
108 |
+
|
109 |
+
if len(vr) < 3 * self.num_frames:
|
110 |
+
continue
|
111 |
+
|
112 |
+
continuous_frames, ref_frames, start_idx = self.get_frames(vr)
|
113 |
+
|
114 |
+
if self.load_audio_data:
|
115 |
+
mel_cache_path = os.path.join(
|
116 |
+
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
117 |
+
)
|
118 |
+
|
119 |
+
if os.path.isfile(mel_cache_path):
|
120 |
+
try:
|
121 |
+
original_mel = torch.load(mel_cache_path)
|
122 |
+
except Exception as e:
|
123 |
+
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
124 |
+
os.remove(mel_cache_path)
|
125 |
+
original_mel = self.read_audio(video_path)
|
126 |
+
torch.save(original_mel, mel_cache_path)
|
127 |
+
else:
|
128 |
+
original_mel = self.read_audio(video_path)
|
129 |
+
torch.save(original_mel, mel_cache_path)
|
130 |
+
|
131 |
+
mel = self.crop_audio_window(original_mel, start_idx)
|
132 |
+
|
133 |
+
if mel.shape[-1] != self.mel_window_length:
|
134 |
+
continue
|
135 |
+
else:
|
136 |
+
mel = []
|
137 |
+
|
138 |
+
gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
|
139 |
+
continuous_frames, affine_transform=False
|
140 |
+
)
|
141 |
+
|
142 |
+
if self.mask == "fix_mask":
|
143 |
+
ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
|
144 |
+
else:
|
145 |
+
ref = image_processor.process_images(ref_frames)
|
146 |
+
vr.seek(0) # avoid memory leak
|
147 |
+
break
|
148 |
+
|
149 |
+
except Exception as e: # Handle the exception of face not detcted
|
150 |
+
print(f"{type(e).__name__} - {e} - {video_path}")
|
151 |
+
if "vr" in locals():
|
152 |
+
vr.seek(0) # avoid memory leak
|
153 |
+
|
154 |
+
sample = dict(
|
155 |
+
gt=gt,
|
156 |
+
masked_gt=masked_gt,
|
157 |
+
ref=ref,
|
158 |
+
mel=mel,
|
159 |
+
mask=mask,
|
160 |
+
video_path=video_path,
|
161 |
+
start_idx=start_idx,
|
162 |
+
)
|
163 |
+
|
164 |
+
return sample
|
latentsync/models/attention.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from turtle import forward
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
from diffusers.modeling_utils import ModelMixin
|
13 |
+
from diffusers.utils import BaseOutput
|
14 |
+
from diffusers.utils.import_utils import is_xformers_available
|
15 |
+
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
|
16 |
+
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
from .utils import zero_module
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Transformer3DModelOutput(BaseOutput):
|
23 |
+
sample: torch.FloatTensor
|
24 |
+
|
25 |
+
|
26 |
+
if is_xformers_available():
|
27 |
+
import xformers
|
28 |
+
import xformers.ops
|
29 |
+
else:
|
30 |
+
xformers = None
|
31 |
+
|
32 |
+
|
33 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
34 |
+
@register_to_config
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
num_attention_heads: int = 16,
|
38 |
+
attention_head_dim: int = 88,
|
39 |
+
in_channels: Optional[int] = None,
|
40 |
+
num_layers: int = 1,
|
41 |
+
dropout: float = 0.0,
|
42 |
+
norm_num_groups: int = 32,
|
43 |
+
cross_attention_dim: Optional[int] = None,
|
44 |
+
attention_bias: bool = False,
|
45 |
+
activation_fn: str = "geglu",
|
46 |
+
num_embeds_ada_norm: Optional[int] = None,
|
47 |
+
use_linear_projection: bool = False,
|
48 |
+
only_cross_attention: bool = False,
|
49 |
+
upcast_attention: bool = False,
|
50 |
+
use_motion_module: bool = False,
|
51 |
+
unet_use_cross_frame_attention=None,
|
52 |
+
unet_use_temporal_attention=None,
|
53 |
+
add_audio_layer=False,
|
54 |
+
audio_condition_method="cross_attn",
|
55 |
+
custom_audio_layer: bool = False,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.use_linear_projection = use_linear_projection
|
59 |
+
self.num_attention_heads = num_attention_heads
|
60 |
+
self.attention_head_dim = attention_head_dim
|
61 |
+
inner_dim = num_attention_heads * attention_head_dim
|
62 |
+
|
63 |
+
# Define input layers
|
64 |
+
self.in_channels = in_channels
|
65 |
+
|
66 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
67 |
+
if use_linear_projection:
|
68 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
69 |
+
else:
|
70 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
71 |
+
|
72 |
+
if not custom_audio_layer:
|
73 |
+
# Define transformers blocks
|
74 |
+
self.transformer_blocks = nn.ModuleList(
|
75 |
+
[
|
76 |
+
BasicTransformerBlock(
|
77 |
+
inner_dim,
|
78 |
+
num_attention_heads,
|
79 |
+
attention_head_dim,
|
80 |
+
dropout=dropout,
|
81 |
+
cross_attention_dim=cross_attention_dim,
|
82 |
+
activation_fn=activation_fn,
|
83 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
84 |
+
attention_bias=attention_bias,
|
85 |
+
only_cross_attention=only_cross_attention,
|
86 |
+
upcast_attention=upcast_attention,
|
87 |
+
use_motion_module=use_motion_module,
|
88 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
89 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
90 |
+
add_audio_layer=add_audio_layer,
|
91 |
+
custom_audio_layer=custom_audio_layer,
|
92 |
+
audio_condition_method=audio_condition_method,
|
93 |
+
)
|
94 |
+
for d in range(num_layers)
|
95 |
+
]
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
self.transformer_blocks = nn.ModuleList(
|
99 |
+
[
|
100 |
+
AudioTransformerBlock(
|
101 |
+
inner_dim,
|
102 |
+
num_attention_heads,
|
103 |
+
attention_head_dim,
|
104 |
+
dropout=dropout,
|
105 |
+
cross_attention_dim=cross_attention_dim,
|
106 |
+
activation_fn=activation_fn,
|
107 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
108 |
+
attention_bias=attention_bias,
|
109 |
+
only_cross_attention=only_cross_attention,
|
110 |
+
upcast_attention=upcast_attention,
|
111 |
+
use_motion_module=use_motion_module,
|
112 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
113 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
114 |
+
add_audio_layer=add_audio_layer,
|
115 |
+
)
|
116 |
+
for d in range(num_layers)
|
117 |
+
]
|
118 |
+
)
|
119 |
+
|
120 |
+
# 4. Define output layers
|
121 |
+
if use_linear_projection:
|
122 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
123 |
+
else:
|
124 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
125 |
+
|
126 |
+
if custom_audio_layer:
|
127 |
+
self.proj_out = zero_module(self.proj_out)
|
128 |
+
|
129 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
130 |
+
# Input
|
131 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
132 |
+
video_length = hidden_states.shape[2]
|
133 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
134 |
+
|
135 |
+
# No need to do this for audio input, because different audio samples are independent
|
136 |
+
# encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
137 |
+
|
138 |
+
batch, channel, height, weight = hidden_states.shape
|
139 |
+
residual = hidden_states
|
140 |
+
|
141 |
+
hidden_states = self.norm(hidden_states)
|
142 |
+
if not self.use_linear_projection:
|
143 |
+
hidden_states = self.proj_in(hidden_states)
|
144 |
+
inner_dim = hidden_states.shape[1]
|
145 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
146 |
+
else:
|
147 |
+
inner_dim = hidden_states.shape[1]
|
148 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
149 |
+
hidden_states = self.proj_in(hidden_states)
|
150 |
+
|
151 |
+
# Blocks
|
152 |
+
for block in self.transformer_blocks:
|
153 |
+
hidden_states = block(
|
154 |
+
hidden_states,
|
155 |
+
encoder_hidden_states=encoder_hidden_states,
|
156 |
+
timestep=timestep,
|
157 |
+
video_length=video_length,
|
158 |
+
)
|
159 |
+
|
160 |
+
# Output
|
161 |
+
if not self.use_linear_projection:
|
162 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
163 |
+
hidden_states = self.proj_out(hidden_states)
|
164 |
+
else:
|
165 |
+
hidden_states = self.proj_out(hidden_states)
|
166 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
167 |
+
|
168 |
+
output = hidden_states + residual
|
169 |
+
|
170 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
171 |
+
if not return_dict:
|
172 |
+
return (output,)
|
173 |
+
|
174 |
+
return Transformer3DModelOutput(sample=output)
|
175 |
+
|
176 |
+
|
177 |
+
class BasicTransformerBlock(nn.Module):
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
dim: int,
|
181 |
+
num_attention_heads: int,
|
182 |
+
attention_head_dim: int,
|
183 |
+
dropout=0.0,
|
184 |
+
cross_attention_dim: Optional[int] = None,
|
185 |
+
activation_fn: str = "geglu",
|
186 |
+
num_embeds_ada_norm: Optional[int] = None,
|
187 |
+
attention_bias: bool = False,
|
188 |
+
only_cross_attention: bool = False,
|
189 |
+
upcast_attention: bool = False,
|
190 |
+
use_motion_module: bool = False,
|
191 |
+
unet_use_cross_frame_attention=None,
|
192 |
+
unet_use_temporal_attention=None,
|
193 |
+
add_audio_layer=False,
|
194 |
+
custom_audio_layer=False,
|
195 |
+
audio_condition_method="cross_attn",
|
196 |
+
):
|
197 |
+
super().__init__()
|
198 |
+
self.only_cross_attention = only_cross_attention
|
199 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
200 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
201 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
202 |
+
self.use_motion_module = use_motion_module
|
203 |
+
self.add_audio_layer = add_audio_layer
|
204 |
+
|
205 |
+
# SC-Attn
|
206 |
+
assert unet_use_cross_frame_attention is not None
|
207 |
+
if unet_use_cross_frame_attention:
|
208 |
+
raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
|
209 |
+
else:
|
210 |
+
self.attn1 = CrossAttention(
|
211 |
+
query_dim=dim,
|
212 |
+
heads=num_attention_heads,
|
213 |
+
dim_head=attention_head_dim,
|
214 |
+
dropout=dropout,
|
215 |
+
bias=attention_bias,
|
216 |
+
upcast_attention=upcast_attention,
|
217 |
+
)
|
218 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
219 |
+
|
220 |
+
# Cross-Attn
|
221 |
+
if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
|
222 |
+
self.audio_cross_attn = AudioCrossAttn(
|
223 |
+
dim=dim,
|
224 |
+
cross_attention_dim=cross_attention_dim,
|
225 |
+
num_attention_heads=num_attention_heads,
|
226 |
+
attention_head_dim=attention_head_dim,
|
227 |
+
dropout=dropout,
|
228 |
+
attention_bias=attention_bias,
|
229 |
+
upcast_attention=upcast_attention,
|
230 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
231 |
+
use_ada_layer_norm=self.use_ada_layer_norm,
|
232 |
+
zero_proj_out=False,
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
self.audio_cross_attn = None
|
236 |
+
|
237 |
+
# Feed-forward
|
238 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
239 |
+
self.norm3 = nn.LayerNorm(dim)
|
240 |
+
|
241 |
+
# Temp-Attn
|
242 |
+
assert unet_use_temporal_attention is not None
|
243 |
+
if unet_use_temporal_attention:
|
244 |
+
self.attn_temp = CrossAttention(
|
245 |
+
query_dim=dim,
|
246 |
+
heads=num_attention_heads,
|
247 |
+
dim_head=attention_head_dim,
|
248 |
+
dropout=dropout,
|
249 |
+
bias=attention_bias,
|
250 |
+
upcast_attention=upcast_attention,
|
251 |
+
)
|
252 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
253 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
254 |
+
|
255 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
256 |
+
if not is_xformers_available():
|
257 |
+
print("Here is how to install it")
|
258 |
+
raise ModuleNotFoundError(
|
259 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
260 |
+
" xformers",
|
261 |
+
name="xformers",
|
262 |
+
)
|
263 |
+
elif not torch.cuda.is_available():
|
264 |
+
raise ValueError(
|
265 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
266 |
+
" available for GPU "
|
267 |
+
)
|
268 |
+
else:
|
269 |
+
try:
|
270 |
+
# Make sure we can run the memory efficient attention
|
271 |
+
_ = xformers.ops.memory_efficient_attention(
|
272 |
+
torch.randn((1, 2, 40), device="cuda"),
|
273 |
+
torch.randn((1, 2, 40), device="cuda"),
|
274 |
+
torch.randn((1, 2, 40), device="cuda"),
|
275 |
+
)
|
276 |
+
except Exception as e:
|
277 |
+
raise e
|
278 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
279 |
+
if self.audio_cross_attn is not None:
|
280 |
+
self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
|
281 |
+
use_memory_efficient_attention_xformers
|
282 |
+
)
|
283 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
284 |
+
|
285 |
+
def forward(
|
286 |
+
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
|
287 |
+
):
|
288 |
+
# SparseCausal-Attention
|
289 |
+
norm_hidden_states = (
|
290 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
291 |
+
)
|
292 |
+
|
293 |
+
# if self.only_cross_attention:
|
294 |
+
# hidden_states = (
|
295 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
296 |
+
# )
|
297 |
+
# else:
|
298 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
299 |
+
|
300 |
+
# pdb.set_trace()
|
301 |
+
if self.unet_use_cross_frame_attention:
|
302 |
+
hidden_states = (
|
303 |
+
self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
|
304 |
+
+ hidden_states
|
305 |
+
)
|
306 |
+
else:
|
307 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
308 |
+
|
309 |
+
if self.audio_cross_attn is not None and encoder_hidden_states is not None:
|
310 |
+
hidden_states = self.audio_cross_attn(
|
311 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
312 |
+
)
|
313 |
+
|
314 |
+
# Feed-forward
|
315 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
316 |
+
|
317 |
+
# Temporal-Attention
|
318 |
+
if self.unet_use_temporal_attention:
|
319 |
+
d = hidden_states.shape[1]
|
320 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
321 |
+
norm_hidden_states = (
|
322 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
323 |
+
)
|
324 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
325 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
326 |
+
|
327 |
+
return hidden_states
|
328 |
+
|
329 |
+
|
330 |
+
class AudioTransformerBlock(nn.Module):
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
dim: int,
|
334 |
+
num_attention_heads: int,
|
335 |
+
attention_head_dim: int,
|
336 |
+
dropout=0.0,
|
337 |
+
cross_attention_dim: Optional[int] = None,
|
338 |
+
activation_fn: str = "geglu",
|
339 |
+
num_embeds_ada_norm: Optional[int] = None,
|
340 |
+
attention_bias: bool = False,
|
341 |
+
only_cross_attention: bool = False,
|
342 |
+
upcast_attention: bool = False,
|
343 |
+
use_motion_module: bool = False,
|
344 |
+
unet_use_cross_frame_attention=None,
|
345 |
+
unet_use_temporal_attention=None,
|
346 |
+
add_audio_layer=False,
|
347 |
+
):
|
348 |
+
super().__init__()
|
349 |
+
self.only_cross_attention = only_cross_attention
|
350 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
351 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
352 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
353 |
+
self.use_motion_module = use_motion_module
|
354 |
+
self.add_audio_layer = add_audio_layer
|
355 |
+
|
356 |
+
# SC-Attn
|
357 |
+
assert unet_use_cross_frame_attention is not None
|
358 |
+
if unet_use_cross_frame_attention:
|
359 |
+
raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
|
360 |
+
else:
|
361 |
+
self.attn1 = CrossAttention(
|
362 |
+
query_dim=dim,
|
363 |
+
heads=num_attention_heads,
|
364 |
+
dim_head=attention_head_dim,
|
365 |
+
dropout=dropout,
|
366 |
+
bias=attention_bias,
|
367 |
+
upcast_attention=upcast_attention,
|
368 |
+
)
|
369 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
370 |
+
|
371 |
+
self.audio_cross_attn = AudioCrossAttn(
|
372 |
+
dim=dim,
|
373 |
+
cross_attention_dim=cross_attention_dim,
|
374 |
+
num_attention_heads=num_attention_heads,
|
375 |
+
attention_head_dim=attention_head_dim,
|
376 |
+
dropout=dropout,
|
377 |
+
attention_bias=attention_bias,
|
378 |
+
upcast_attention=upcast_attention,
|
379 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
380 |
+
use_ada_layer_norm=self.use_ada_layer_norm,
|
381 |
+
zero_proj_out=False,
|
382 |
+
)
|
383 |
+
|
384 |
+
# Feed-forward
|
385 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
386 |
+
self.norm3 = nn.LayerNorm(dim)
|
387 |
+
|
388 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
389 |
+
if not is_xformers_available():
|
390 |
+
print("Here is how to install it")
|
391 |
+
raise ModuleNotFoundError(
|
392 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
393 |
+
" xformers",
|
394 |
+
name="xformers",
|
395 |
+
)
|
396 |
+
elif not torch.cuda.is_available():
|
397 |
+
raise ValueError(
|
398 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
399 |
+
" available for GPU "
|
400 |
+
)
|
401 |
+
else:
|
402 |
+
try:
|
403 |
+
# Make sure we can run the memory efficient attention
|
404 |
+
_ = xformers.ops.memory_efficient_attention(
|
405 |
+
torch.randn((1, 2, 40), device="cuda"),
|
406 |
+
torch.randn((1, 2, 40), device="cuda"),
|
407 |
+
torch.randn((1, 2, 40), device="cuda"),
|
408 |
+
)
|
409 |
+
except Exception as e:
|
410 |
+
raise e
|
411 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
412 |
+
if self.audio_cross_attn is not None:
|
413 |
+
self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
|
414 |
+
use_memory_efficient_attention_xformers
|
415 |
+
)
|
416 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
417 |
+
|
418 |
+
def forward(
|
419 |
+
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
|
420 |
+
):
|
421 |
+
# SparseCausal-Attention
|
422 |
+
norm_hidden_states = (
|
423 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
424 |
+
)
|
425 |
+
|
426 |
+
# pdb.set_trace()
|
427 |
+
if self.unet_use_cross_frame_attention:
|
428 |
+
hidden_states = (
|
429 |
+
self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
|
430 |
+
+ hidden_states
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
434 |
+
|
435 |
+
if self.audio_cross_attn is not None and encoder_hidden_states is not None:
|
436 |
+
hidden_states = self.audio_cross_attn(
|
437 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
438 |
+
)
|
439 |
+
|
440 |
+
# Feed-forward
|
441 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
442 |
+
|
443 |
+
return hidden_states
|
444 |
+
|
445 |
+
|
446 |
+
class AudioCrossAttn(nn.Module):
|
447 |
+
def __init__(
|
448 |
+
self,
|
449 |
+
dim,
|
450 |
+
cross_attention_dim,
|
451 |
+
num_attention_heads,
|
452 |
+
attention_head_dim,
|
453 |
+
dropout,
|
454 |
+
attention_bias,
|
455 |
+
upcast_attention,
|
456 |
+
num_embeds_ada_norm,
|
457 |
+
use_ada_layer_norm,
|
458 |
+
zero_proj_out=False,
|
459 |
+
):
|
460 |
+
super().__init__()
|
461 |
+
|
462 |
+
self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
|
463 |
+
self.attn = CrossAttention(
|
464 |
+
query_dim=dim,
|
465 |
+
cross_attention_dim=cross_attention_dim,
|
466 |
+
heads=num_attention_heads,
|
467 |
+
dim_head=attention_head_dim,
|
468 |
+
dropout=dropout,
|
469 |
+
bias=attention_bias,
|
470 |
+
upcast_attention=upcast_attention,
|
471 |
+
)
|
472 |
+
|
473 |
+
if zero_proj_out:
|
474 |
+
self.proj_out = zero_module(nn.Linear(dim, dim))
|
475 |
+
|
476 |
+
self.zero_proj_out = zero_proj_out
|
477 |
+
self.use_ada_layer_norm = use_ada_layer_norm
|
478 |
+
|
479 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
|
480 |
+
previous_hidden_states = hidden_states
|
481 |
+
hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)
|
482 |
+
|
483 |
+
if encoder_hidden_states.dim() == 4:
|
484 |
+
encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")
|
485 |
+
|
486 |
+
hidden_states = self.attn(
|
487 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
488 |
+
)
|
489 |
+
|
490 |
+
if self.zero_proj_out:
|
491 |
+
hidden_states = self.proj_out(hidden_states)
|
492 |
+
return hidden_states + previous_hidden_states
|
latentsync/models/motion_module.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
|
2 |
+
|
3 |
+
# Actually we don't use the motion module in the final version of LatentSync
|
4 |
+
# When we started the project, we used the codebase of AnimateDiff and tried motion module
|
5 |
+
# But the results are poor, and we decied to leave the code here for possible future usage
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
14 |
+
from diffusers.modeling_utils import ModelMixin
|
15 |
+
from diffusers.utils import BaseOutput
|
16 |
+
from diffusers.utils.import_utils import is_xformers_available
|
17 |
+
from diffusers.models.attention import CrossAttention, FeedForward
|
18 |
+
|
19 |
+
from einops import rearrange, repeat
|
20 |
+
import math
|
21 |
+
from .utils import zero_module
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
26 |
+
sample: torch.FloatTensor
|
27 |
+
|
28 |
+
|
29 |
+
if is_xformers_available():
|
30 |
+
import xformers
|
31 |
+
import xformers.ops
|
32 |
+
else:
|
33 |
+
xformers = None
|
34 |
+
|
35 |
+
|
36 |
+
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
|
37 |
+
if motion_module_type == "Vanilla":
|
38 |
+
return VanillaTemporalModule(
|
39 |
+
in_channels=in_channels,
|
40 |
+
**motion_module_kwargs,
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
raise ValueError
|
44 |
+
|
45 |
+
|
46 |
+
class VanillaTemporalModule(nn.Module):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
in_channels,
|
50 |
+
num_attention_heads=8,
|
51 |
+
num_transformer_block=2,
|
52 |
+
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
53 |
+
cross_frame_attention_mode=None,
|
54 |
+
temporal_position_encoding=False,
|
55 |
+
temporal_position_encoding_max_len=24,
|
56 |
+
temporal_attention_dim_div=1,
|
57 |
+
zero_initialize=True,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
62 |
+
in_channels=in_channels,
|
63 |
+
num_attention_heads=num_attention_heads,
|
64 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
65 |
+
num_layers=num_transformer_block,
|
66 |
+
attention_block_types=attention_block_types,
|
67 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
68 |
+
temporal_position_encoding=temporal_position_encoding,
|
69 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
70 |
+
)
|
71 |
+
|
72 |
+
if zero_initialize:
|
73 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
74 |
+
|
75 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
76 |
+
hidden_states = input_tensor
|
77 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
78 |
+
|
79 |
+
output = hidden_states
|
80 |
+
return output
|
81 |
+
|
82 |
+
|
83 |
+
class TemporalTransformer3DModel(nn.Module):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
in_channels,
|
87 |
+
num_attention_heads,
|
88 |
+
attention_head_dim,
|
89 |
+
num_layers,
|
90 |
+
attention_block_types=(
|
91 |
+
"Temporal_Self",
|
92 |
+
"Temporal_Self",
|
93 |
+
),
|
94 |
+
dropout=0.0,
|
95 |
+
norm_num_groups=32,
|
96 |
+
cross_attention_dim=768,
|
97 |
+
activation_fn="geglu",
|
98 |
+
attention_bias=False,
|
99 |
+
upcast_attention=False,
|
100 |
+
cross_frame_attention_mode=None,
|
101 |
+
temporal_position_encoding=False,
|
102 |
+
temporal_position_encoding_max_len=24,
|
103 |
+
):
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
inner_dim = num_attention_heads * attention_head_dim
|
107 |
+
|
108 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
109 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
110 |
+
|
111 |
+
self.transformer_blocks = nn.ModuleList(
|
112 |
+
[
|
113 |
+
TemporalTransformerBlock(
|
114 |
+
dim=inner_dim,
|
115 |
+
num_attention_heads=num_attention_heads,
|
116 |
+
attention_head_dim=attention_head_dim,
|
117 |
+
attention_block_types=attention_block_types,
|
118 |
+
dropout=dropout,
|
119 |
+
norm_num_groups=norm_num_groups,
|
120 |
+
cross_attention_dim=cross_attention_dim,
|
121 |
+
activation_fn=activation_fn,
|
122 |
+
attention_bias=attention_bias,
|
123 |
+
upcast_attention=upcast_attention,
|
124 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
125 |
+
temporal_position_encoding=temporal_position_encoding,
|
126 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
127 |
+
)
|
128 |
+
for d in range(num_layers)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
132 |
+
|
133 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
134 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
135 |
+
video_length = hidden_states.shape[2]
|
136 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
137 |
+
|
138 |
+
batch, channel, height, weight = hidden_states.shape
|
139 |
+
residual = hidden_states
|
140 |
+
|
141 |
+
hidden_states = self.norm(hidden_states)
|
142 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
|
143 |
+
hidden_states = self.proj_in(hidden_states)
|
144 |
+
|
145 |
+
# Transformer Blocks
|
146 |
+
for block in self.transformer_blocks:
|
147 |
+
hidden_states = block(
|
148 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
|
149 |
+
)
|
150 |
+
|
151 |
+
# output
|
152 |
+
hidden_states = self.proj_out(hidden_states)
|
153 |
+
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
|
154 |
+
|
155 |
+
output = hidden_states + residual
|
156 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
157 |
+
|
158 |
+
return output
|
159 |
+
|
160 |
+
|
161 |
+
class TemporalTransformerBlock(nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
dim,
|
165 |
+
num_attention_heads,
|
166 |
+
attention_head_dim,
|
167 |
+
attention_block_types=(
|
168 |
+
"Temporal_Self",
|
169 |
+
"Temporal_Self",
|
170 |
+
),
|
171 |
+
dropout=0.0,
|
172 |
+
norm_num_groups=32,
|
173 |
+
cross_attention_dim=768,
|
174 |
+
activation_fn="geglu",
|
175 |
+
attention_bias=False,
|
176 |
+
upcast_attention=False,
|
177 |
+
cross_frame_attention_mode=None,
|
178 |
+
temporal_position_encoding=False,
|
179 |
+
temporal_position_encoding_max_len=24,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
|
183 |
+
attention_blocks = []
|
184 |
+
norms = []
|
185 |
+
|
186 |
+
for block_name in attention_block_types:
|
187 |
+
attention_blocks.append(
|
188 |
+
VersatileAttention(
|
189 |
+
attention_mode=block_name.split("_")[0],
|
190 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
191 |
+
query_dim=dim,
|
192 |
+
heads=num_attention_heads,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
dropout=dropout,
|
195 |
+
bias=attention_bias,
|
196 |
+
upcast_attention=upcast_attention,
|
197 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
198 |
+
temporal_position_encoding=temporal_position_encoding,
|
199 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
200 |
+
)
|
201 |
+
)
|
202 |
+
norms.append(nn.LayerNorm(dim))
|
203 |
+
|
204 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
205 |
+
self.norms = nn.ModuleList(norms)
|
206 |
+
|
207 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
208 |
+
self.ff_norm = nn.LayerNorm(dim)
|
209 |
+
|
210 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
211 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
212 |
+
norm_hidden_states = norm(hidden_states)
|
213 |
+
hidden_states = (
|
214 |
+
attention_block(
|
215 |
+
norm_hidden_states,
|
216 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
217 |
+
video_length=video_length,
|
218 |
+
)
|
219 |
+
+ hidden_states
|
220 |
+
)
|
221 |
+
|
222 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
223 |
+
|
224 |
+
output = hidden_states
|
225 |
+
return output
|
226 |
+
|
227 |
+
|
228 |
+
class PositionalEncoding(nn.Module):
|
229 |
+
def __init__(self, d_model, dropout=0.0, max_len=24):
|
230 |
+
super().__init__()
|
231 |
+
self.dropout = nn.Dropout(p=dropout)
|
232 |
+
position = torch.arange(max_len).unsqueeze(1)
|
233 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
234 |
+
pe = torch.zeros(1, max_len, d_model)
|
235 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
236 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
237 |
+
self.register_buffer("pe", pe)
|
238 |
+
|
239 |
+
def forward(self, x):
|
240 |
+
x = x + self.pe[:, : x.size(1)]
|
241 |
+
return self.dropout(x)
|
242 |
+
|
243 |
+
|
244 |
+
class VersatileAttention(CrossAttention):
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
attention_mode=None,
|
248 |
+
cross_frame_attention_mode=None,
|
249 |
+
temporal_position_encoding=False,
|
250 |
+
temporal_position_encoding_max_len=24,
|
251 |
+
*args,
|
252 |
+
**kwargs,
|
253 |
+
):
|
254 |
+
super().__init__(*args, **kwargs)
|
255 |
+
assert attention_mode == "Temporal"
|
256 |
+
|
257 |
+
self.attention_mode = attention_mode
|
258 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
259 |
+
|
260 |
+
self.pos_encoder = (
|
261 |
+
PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
|
262 |
+
if (temporal_position_encoding and attention_mode == "Temporal")
|
263 |
+
else None
|
264 |
+
)
|
265 |
+
|
266 |
+
def extra_repr(self):
|
267 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
268 |
+
|
269 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
270 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
271 |
+
|
272 |
+
if self.attention_mode == "Temporal":
|
273 |
+
d = hidden_states.shape[1]
|
274 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
275 |
+
|
276 |
+
if self.pos_encoder is not None:
|
277 |
+
hidden_states = self.pos_encoder(hidden_states)
|
278 |
+
|
279 |
+
encoder_hidden_states = (
|
280 |
+
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
|
281 |
+
if encoder_hidden_states is not None
|
282 |
+
else encoder_hidden_states
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
raise NotImplementedError
|
286 |
+
|
287 |
+
# encoder_hidden_states = encoder_hidden_states
|
288 |
+
|
289 |
+
if self.group_norm is not None:
|
290 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
291 |
+
|
292 |
+
query = self.to_q(hidden_states)
|
293 |
+
dim = query.shape[-1]
|
294 |
+
query = self.reshape_heads_to_batch_dim(query)
|
295 |
+
|
296 |
+
if self.added_kv_proj_dim is not None:
|
297 |
+
raise NotImplementedError
|
298 |
+
|
299 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
300 |
+
key = self.to_k(encoder_hidden_states)
|
301 |
+
value = self.to_v(encoder_hidden_states)
|
302 |
+
|
303 |
+
key = self.reshape_heads_to_batch_dim(key)
|
304 |
+
value = self.reshape_heads_to_batch_dim(value)
|
305 |
+
|
306 |
+
if attention_mask is not None:
|
307 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
308 |
+
target_length = query.shape[1]
|
309 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
310 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
311 |
+
|
312 |
+
# attention, what we cannot get enough of
|
313 |
+
if self._use_memory_efficient_attention_xformers:
|
314 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
315 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
316 |
+
hidden_states = hidden_states.to(query.dtype)
|
317 |
+
else:
|
318 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
319 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
320 |
+
else:
|
321 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
322 |
+
|
323 |
+
# linear proj
|
324 |
+
hidden_states = self.to_out[0](hidden_states)
|
325 |
+
|
326 |
+
# dropout
|
327 |
+
hidden_states = self.to_out[1](hidden_states)
|
328 |
+
|
329 |
+
if self.attention_mode == "Temporal":
|
330 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
331 |
+
|
332 |
+
return hidden_states
|
latentsync/models/resnet.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class InflatedConv3d(nn.Conv2d):
|
11 |
+
def forward(self, x):
|
12 |
+
video_length = x.shape[2]
|
13 |
+
|
14 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
15 |
+
x = super().forward(x)
|
16 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
17 |
+
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
video_length = x.shape[2]
|
24 |
+
|
25 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
26 |
+
x = super().forward(x)
|
27 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
28 |
+
|
29 |
+
return x
|
30 |
+
|
31 |
+
|
32 |
+
class Upsample3D(nn.Module):
|
33 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
34 |
+
super().__init__()
|
35 |
+
self.channels = channels
|
36 |
+
self.out_channels = out_channels or channels
|
37 |
+
self.use_conv = use_conv
|
38 |
+
self.use_conv_transpose = use_conv_transpose
|
39 |
+
self.name = name
|
40 |
+
|
41 |
+
conv = None
|
42 |
+
if use_conv_transpose:
|
43 |
+
raise NotImplementedError
|
44 |
+
elif use_conv:
|
45 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
46 |
+
|
47 |
+
def forward(self, hidden_states, output_size=None):
|
48 |
+
assert hidden_states.shape[1] == self.channels
|
49 |
+
|
50 |
+
if self.use_conv_transpose:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
54 |
+
dtype = hidden_states.dtype
|
55 |
+
if dtype == torch.bfloat16:
|
56 |
+
hidden_states = hidden_states.to(torch.float32)
|
57 |
+
|
58 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
59 |
+
if hidden_states.shape[0] >= 64:
|
60 |
+
hidden_states = hidden_states.contiguous()
|
61 |
+
|
62 |
+
# if `output_size` is passed we force the interpolation output
|
63 |
+
# size and do not make use of `scale_factor=2`
|
64 |
+
if output_size is None:
|
65 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
66 |
+
else:
|
67 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
68 |
+
|
69 |
+
# If the input is bfloat16, we cast back to bfloat16
|
70 |
+
if dtype == torch.bfloat16:
|
71 |
+
hidden_states = hidden_states.to(dtype)
|
72 |
+
|
73 |
+
# if self.use_conv:
|
74 |
+
# if self.name == "conv":
|
75 |
+
# hidden_states = self.conv(hidden_states)
|
76 |
+
# else:
|
77 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
78 |
+
hidden_states = self.conv(hidden_states)
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
|
83 |
+
class Downsample3D(nn.Module):
|
84 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
85 |
+
super().__init__()
|
86 |
+
self.channels = channels
|
87 |
+
self.out_channels = out_channels or channels
|
88 |
+
self.use_conv = use_conv
|
89 |
+
self.padding = padding
|
90 |
+
stride = 2
|
91 |
+
self.name = name
|
92 |
+
|
93 |
+
if use_conv:
|
94 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
assert hidden_states.shape[1] == self.channels
|
100 |
+
if self.use_conv and self.padding == 0:
|
101 |
+
raise NotImplementedError
|
102 |
+
|
103 |
+
assert hidden_states.shape[1] == self.channels
|
104 |
+
hidden_states = self.conv(hidden_states)
|
105 |
+
|
106 |
+
return hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class ResnetBlock3D(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
*,
|
113 |
+
in_channels,
|
114 |
+
out_channels=None,
|
115 |
+
conv_shortcut=False,
|
116 |
+
dropout=0.0,
|
117 |
+
temb_channels=512,
|
118 |
+
groups=32,
|
119 |
+
groups_out=None,
|
120 |
+
pre_norm=True,
|
121 |
+
eps=1e-6,
|
122 |
+
non_linearity="swish",
|
123 |
+
time_embedding_norm="default",
|
124 |
+
output_scale_factor=1.0,
|
125 |
+
use_in_shortcut=None,
|
126 |
+
use_inflated_groupnorm=False,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.pre_norm = pre_norm
|
130 |
+
self.pre_norm = True
|
131 |
+
self.in_channels = in_channels
|
132 |
+
out_channels = in_channels if out_channels is None else out_channels
|
133 |
+
self.out_channels = out_channels
|
134 |
+
self.use_conv_shortcut = conv_shortcut
|
135 |
+
self.time_embedding_norm = time_embedding_norm
|
136 |
+
self.output_scale_factor = output_scale_factor
|
137 |
+
|
138 |
+
if groups_out is None:
|
139 |
+
groups_out = groups
|
140 |
+
|
141 |
+
assert use_inflated_groupnorm != None
|
142 |
+
if use_inflated_groupnorm:
|
143 |
+
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
144 |
+
else:
|
145 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
146 |
+
|
147 |
+
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
148 |
+
|
149 |
+
if temb_channels is not None:
|
150 |
+
time_emb_proj_out_channels = out_channels
|
151 |
+
# if self.time_embedding_norm == "default":
|
152 |
+
# time_emb_proj_out_channels = out_channels
|
153 |
+
# elif self.time_embedding_norm == "scale_shift":
|
154 |
+
# time_emb_proj_out_channels = out_channels * 2
|
155 |
+
# else:
|
156 |
+
# raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
157 |
+
|
158 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
159 |
+
else:
|
160 |
+
self.time_emb_proj = None
|
161 |
+
|
162 |
+
if self.time_embedding_norm == "scale_shift":
|
163 |
+
self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
|
164 |
+
else:
|
165 |
+
self.double_len_linear = None
|
166 |
+
|
167 |
+
if use_inflated_groupnorm:
|
168 |
+
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
169 |
+
else:
|
170 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
171 |
+
|
172 |
+
self.dropout = torch.nn.Dropout(dropout)
|
173 |
+
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
174 |
+
|
175 |
+
if non_linearity == "swish":
|
176 |
+
self.nonlinearity = lambda x: F.silu(x)
|
177 |
+
elif non_linearity == "mish":
|
178 |
+
self.nonlinearity = Mish()
|
179 |
+
elif non_linearity == "silu":
|
180 |
+
self.nonlinearity = nn.SiLU()
|
181 |
+
|
182 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
183 |
+
|
184 |
+
self.conv_shortcut = None
|
185 |
+
if self.use_in_shortcut:
|
186 |
+
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
187 |
+
|
188 |
+
def forward(self, input_tensor, temb):
|
189 |
+
hidden_states = input_tensor
|
190 |
+
|
191 |
+
hidden_states = self.norm1(hidden_states)
|
192 |
+
hidden_states = self.nonlinearity(hidden_states)
|
193 |
+
|
194 |
+
hidden_states = self.conv1(hidden_states)
|
195 |
+
|
196 |
+
if temb is not None:
|
197 |
+
if temb.dim() == 2:
|
198 |
+
# input (1, 1280)
|
199 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))
|
200 |
+
temb = temb[:, :, None, None, None] # unsqueeze
|
201 |
+
else:
|
202 |
+
# input (1, 1280, 16)
|
203 |
+
temb = temb.permute(0, 2, 1)
|
204 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))
|
205 |
+
if self.double_len_linear is not None:
|
206 |
+
temb = self.double_len_linear(self.nonlinearity(temb))
|
207 |
+
temb = temb.permute(0, 2, 1)
|
208 |
+
temb = temb[:, :, :, None, None]
|
209 |
+
|
210 |
+
if temb is not None and self.time_embedding_norm == "default":
|
211 |
+
hidden_states = hidden_states + temb
|
212 |
+
|
213 |
+
hidden_states = self.norm2(hidden_states)
|
214 |
+
|
215 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
216 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
217 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
218 |
+
|
219 |
+
hidden_states = self.nonlinearity(hidden_states)
|
220 |
+
|
221 |
+
hidden_states = self.dropout(hidden_states)
|
222 |
+
hidden_states = self.conv2(hidden_states)
|
223 |
+
|
224 |
+
if self.conv_shortcut is not None:
|
225 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
226 |
+
|
227 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
228 |
+
|
229 |
+
return output_tensor
|
230 |
+
|
231 |
+
|
232 |
+
class Mish(torch.nn.Module):
|
233 |
+
def forward(self, hidden_states):
|
234 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
latentsync/models/syncnet.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
from einops import rearrange
|
18 |
+
from torch.nn import functional as F
|
19 |
+
from ..utils.util import cosine_loss
|
20 |
+
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from diffusers.models.attention import CrossAttention, FeedForward
|
25 |
+
from diffusers.utils.import_utils import is_xformers_available
|
26 |
+
from einops import rearrange
|
27 |
+
|
28 |
+
|
29 |
+
class SyncNet(nn.Module):
|
30 |
+
def __init__(self, config):
|
31 |
+
super().__init__()
|
32 |
+
self.audio_encoder = DownEncoder2D(
|
33 |
+
in_channels=config["audio_encoder"]["in_channels"],
|
34 |
+
block_out_channels=config["audio_encoder"]["block_out_channels"],
|
35 |
+
downsample_factors=config["audio_encoder"]["downsample_factors"],
|
36 |
+
dropout=config["audio_encoder"]["dropout"],
|
37 |
+
attn_blocks=config["audio_encoder"]["attn_blocks"],
|
38 |
+
)
|
39 |
+
|
40 |
+
self.visual_encoder = DownEncoder2D(
|
41 |
+
in_channels=config["visual_encoder"]["in_channels"],
|
42 |
+
block_out_channels=config["visual_encoder"]["block_out_channels"],
|
43 |
+
downsample_factors=config["visual_encoder"]["downsample_factors"],
|
44 |
+
dropout=config["visual_encoder"]["dropout"],
|
45 |
+
attn_blocks=config["visual_encoder"]["attn_blocks"],
|
46 |
+
)
|
47 |
+
|
48 |
+
self.eval()
|
49 |
+
|
50 |
+
def forward(self, image_sequences, audio_sequences):
|
51 |
+
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
52 |
+
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
53 |
+
|
54 |
+
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
55 |
+
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
56 |
+
|
57 |
+
# Make them unit vectors
|
58 |
+
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
59 |
+
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
60 |
+
|
61 |
+
return vision_embeds, audio_embeds
|
62 |
+
|
63 |
+
|
64 |
+
class ResnetBlock2D(nn.Module):
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
in_channels: int,
|
68 |
+
out_channels: int,
|
69 |
+
dropout: float = 0.0,
|
70 |
+
norm_num_groups: int = 32,
|
71 |
+
eps: float = 1e-6,
|
72 |
+
act_fn: str = "silu",
|
73 |
+
downsample_factor=2,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
78 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
79 |
+
|
80 |
+
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
|
81 |
+
self.dropout = nn.Dropout(dropout)
|
82 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
83 |
+
|
84 |
+
if act_fn == "relu":
|
85 |
+
self.act_fn = nn.ReLU()
|
86 |
+
elif act_fn == "silu":
|
87 |
+
self.act_fn = nn.SiLU()
|
88 |
+
|
89 |
+
if in_channels != out_channels:
|
90 |
+
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
91 |
+
else:
|
92 |
+
self.conv_shortcut = None
|
93 |
+
|
94 |
+
if isinstance(downsample_factor, list):
|
95 |
+
downsample_factor = tuple(downsample_factor)
|
96 |
+
|
97 |
+
if downsample_factor == 1:
|
98 |
+
self.downsample_conv = None
|
99 |
+
else:
|
100 |
+
self.downsample_conv = nn.Conv2d(
|
101 |
+
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
|
102 |
+
)
|
103 |
+
self.pad = (0, 1, 0, 1)
|
104 |
+
if isinstance(downsample_factor, tuple):
|
105 |
+
if downsample_factor[0] == 1:
|
106 |
+
self.pad = (0, 1, 1, 1) # The padding order is from back to front
|
107 |
+
elif downsample_factor[1] == 1:
|
108 |
+
self.pad = (1, 1, 0, 1)
|
109 |
+
|
110 |
+
def forward(self, input_tensor):
|
111 |
+
hidden_states = input_tensor
|
112 |
+
|
113 |
+
hidden_states = self.norm1(hidden_states)
|
114 |
+
hidden_states = self.act_fn(hidden_states)
|
115 |
+
|
116 |
+
hidden_states = self.conv1(hidden_states)
|
117 |
+
hidden_states = self.norm2(hidden_states)
|
118 |
+
hidden_states = self.act_fn(hidden_states)
|
119 |
+
|
120 |
+
hidden_states = self.dropout(hidden_states)
|
121 |
+
hidden_states = self.conv2(hidden_states)
|
122 |
+
|
123 |
+
if self.conv_shortcut is not None:
|
124 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
125 |
+
|
126 |
+
hidden_states += input_tensor
|
127 |
+
|
128 |
+
if self.downsample_conv is not None:
|
129 |
+
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
130 |
+
hidden_states = self.downsample_conv(hidden_states)
|
131 |
+
|
132 |
+
return hidden_states
|
133 |
+
|
134 |
+
|
135 |
+
class AttentionBlock2D(nn.Module):
|
136 |
+
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
137 |
+
super().__init__()
|
138 |
+
if not is_xformers_available():
|
139 |
+
raise ModuleNotFoundError(
|
140 |
+
"You have to install xformers to enable memory efficient attetion", name="xformers"
|
141 |
+
)
|
142 |
+
# inner_dim = dim_head * heads
|
143 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
144 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
145 |
+
self.norm3 = nn.LayerNorm(query_dim)
|
146 |
+
|
147 |
+
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
148 |
+
|
149 |
+
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
150 |
+
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
151 |
+
|
152 |
+
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
153 |
+
self.attn._use_memory_efficient_attention_xformers = True
|
154 |
+
|
155 |
+
def forward(self, hidden_states):
|
156 |
+
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
157 |
+
|
158 |
+
batch, channel, height, width = hidden_states.shape
|
159 |
+
residual = hidden_states
|
160 |
+
|
161 |
+
hidden_states = self.norm1(hidden_states)
|
162 |
+
hidden_states = self.conv_in(hidden_states)
|
163 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
164 |
+
|
165 |
+
norm_hidden_states = self.norm2(hidden_states)
|
166 |
+
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
167 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
168 |
+
|
169 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
|
170 |
+
hidden_states = self.conv_out(hidden_states)
|
171 |
+
|
172 |
+
hidden_states = hidden_states + residual
|
173 |
+
return hidden_states
|
174 |
+
|
175 |
+
|
176 |
+
class DownEncoder2D(nn.Module):
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
in_channels=4 * 16,
|
180 |
+
block_out_channels=[64, 128, 256, 256],
|
181 |
+
downsample_factors=[2, 2, 2, 2],
|
182 |
+
layers_per_block=2,
|
183 |
+
norm_num_groups=32,
|
184 |
+
attn_blocks=[1, 1, 1, 1],
|
185 |
+
dropout: float = 0.0,
|
186 |
+
act_fn="silu",
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
self.layers_per_block = layers_per_block
|
190 |
+
|
191 |
+
# in
|
192 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
193 |
+
|
194 |
+
# down
|
195 |
+
self.down_blocks = nn.ModuleList([])
|
196 |
+
|
197 |
+
output_channels = block_out_channels[0]
|
198 |
+
for i, block_out_channel in enumerate(block_out_channels):
|
199 |
+
input_channels = output_channels
|
200 |
+
output_channels = block_out_channel
|
201 |
+
# is_final_block = i == len(block_out_channels) - 1
|
202 |
+
|
203 |
+
down_block = ResnetBlock2D(
|
204 |
+
in_channels=input_channels,
|
205 |
+
out_channels=output_channels,
|
206 |
+
downsample_factor=downsample_factors[i],
|
207 |
+
norm_num_groups=norm_num_groups,
|
208 |
+
dropout=dropout,
|
209 |
+
act_fn=act_fn,
|
210 |
+
)
|
211 |
+
|
212 |
+
self.down_blocks.append(down_block)
|
213 |
+
|
214 |
+
if attn_blocks[i] == 1:
|
215 |
+
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
216 |
+
self.down_blocks.append(attention_block)
|
217 |
+
|
218 |
+
# out
|
219 |
+
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
220 |
+
self.act_fn_out = nn.ReLU()
|
221 |
+
|
222 |
+
def forward(self, hidden_states):
|
223 |
+
hidden_states = self.conv_in(hidden_states)
|
224 |
+
|
225 |
+
# down
|
226 |
+
for down_block in self.down_blocks:
|
227 |
+
hidden_states = down_block(hidden_states)
|
228 |
+
|
229 |
+
# post-process
|
230 |
+
hidden_states = self.norm_out(hidden_states)
|
231 |
+
hidden_states = self.act_fn_out(hidden_states)
|
232 |
+
|
233 |
+
return hidden_states
|
latentsync/models/syncnet_wav2lip.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
|
2 |
+
# The code here is for ablation study.
|
3 |
+
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class SyncNetWav2Lip(nn.Module):
|
9 |
+
def __init__(self, act_fn="leaky"):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
# input image sequences: (15, 128, 256)
|
13 |
+
self.visual_encoder = nn.Sequential(
|
14 |
+
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
|
15 |
+
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
|
16 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
17 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
18 |
+
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
|
19 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
20 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
21 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
22 |
+
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
|
23 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
24 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
25 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
|
26 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
27 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
28 |
+
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
|
29 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
30 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
31 |
+
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
|
32 |
+
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
33 |
+
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
34 |
+
)
|
35 |
+
|
36 |
+
# input audio sequences: (1, 80, 16)
|
37 |
+
self.audio_encoder = nn.Sequential(
|
38 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
39 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
40 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
41 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
|
42 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
43 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
44 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
|
45 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
46 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
47 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
|
48 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
49 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
50 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
51 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
52 |
+
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
53 |
+
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
54 |
+
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
55 |
+
)
|
56 |
+
|
57 |
+
def forward(self, image_sequences, audio_sequences):
|
58 |
+
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
59 |
+
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
60 |
+
|
61 |
+
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
62 |
+
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
63 |
+
|
64 |
+
# Make them unit vectors
|
65 |
+
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
66 |
+
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
67 |
+
|
68 |
+
return vision_embeds, audio_embeds
|
69 |
+
|
70 |
+
|
71 |
+
class Conv2d(nn.Module):
|
72 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
|
73 |
+
super().__init__(*args, **kwargs)
|
74 |
+
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
|
75 |
+
if act_fn == "relu":
|
76 |
+
self.act_fn = nn.ReLU()
|
77 |
+
elif act_fn == "tanh":
|
78 |
+
self.act_fn = nn.Tanh()
|
79 |
+
elif act_fn == "silu":
|
80 |
+
self.act_fn = nn.SiLU()
|
81 |
+
elif act_fn == "leaky":
|
82 |
+
self.act_fn = nn.LeakyReLU(0.2, inplace=True)
|
83 |
+
|
84 |
+
self.residual = residual
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
out = self.conv_block(x)
|
88 |
+
if self.residual:
|
89 |
+
out += x
|
90 |
+
return self.act_fn(out)
|
latentsync/models/unet.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
|
11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
from diffusers.modeling_utils import ModelMixin
|
13 |
+
from diffusers import UNet2DConditionModel
|
14 |
+
from diffusers.utils import BaseOutput, logging
|
15 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
16 |
+
from .unet_blocks import (
|
17 |
+
CrossAttnDownBlock3D,
|
18 |
+
CrossAttnUpBlock3D,
|
19 |
+
DownBlock3D,
|
20 |
+
UNetMidBlock3DCrossAttn,
|
21 |
+
UpBlock3D,
|
22 |
+
get_down_block,
|
23 |
+
get_up_block,
|
24 |
+
)
|
25 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
26 |
+
|
27 |
+
from ..utils.util import zero_rank_log
|
28 |
+
from einops import rearrange
|
29 |
+
from .utils import zero_module
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class UNet3DConditionOutput(BaseOutput):
|
37 |
+
sample: torch.FloatTensor
|
38 |
+
|
39 |
+
|
40 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
41 |
+
_supports_gradient_checkpointing = True
|
42 |
+
|
43 |
+
@register_to_config
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
sample_size: Optional[int] = None,
|
47 |
+
in_channels: int = 4,
|
48 |
+
out_channels: int = 4,
|
49 |
+
center_input_sample: bool = False,
|
50 |
+
flip_sin_to_cos: bool = True,
|
51 |
+
freq_shift: int = 0,
|
52 |
+
down_block_types: Tuple[str] = (
|
53 |
+
"CrossAttnDownBlock3D",
|
54 |
+
"CrossAttnDownBlock3D",
|
55 |
+
"CrossAttnDownBlock3D",
|
56 |
+
"DownBlock3D",
|
57 |
+
),
|
58 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
59 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
60 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
61 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
62 |
+
layers_per_block: int = 2,
|
63 |
+
downsample_padding: int = 1,
|
64 |
+
mid_block_scale_factor: float = 1,
|
65 |
+
act_fn: str = "silu",
|
66 |
+
norm_num_groups: int = 32,
|
67 |
+
norm_eps: float = 1e-5,
|
68 |
+
cross_attention_dim: int = 1280,
|
69 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
70 |
+
dual_cross_attention: bool = False,
|
71 |
+
use_linear_projection: bool = False,
|
72 |
+
class_embed_type: Optional[str] = None,
|
73 |
+
num_class_embeds: Optional[int] = None,
|
74 |
+
upcast_attention: bool = False,
|
75 |
+
resnet_time_scale_shift: str = "default",
|
76 |
+
use_inflated_groupnorm=False,
|
77 |
+
# Additional
|
78 |
+
use_motion_module=False,
|
79 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
80 |
+
motion_module_mid_block=False,
|
81 |
+
motion_module_decoder_only=False,
|
82 |
+
motion_module_type=None,
|
83 |
+
motion_module_kwargs={},
|
84 |
+
unet_use_cross_frame_attention=False,
|
85 |
+
unet_use_temporal_attention=False,
|
86 |
+
add_audio_layer=False,
|
87 |
+
audio_condition_method: str = "cross_attn",
|
88 |
+
custom_audio_layer=False,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
|
92 |
+
self.sample_size = sample_size
|
93 |
+
time_embed_dim = block_out_channels[0] * 4
|
94 |
+
self.use_motion_module = use_motion_module
|
95 |
+
self.add_audio_layer = add_audio_layer
|
96 |
+
|
97 |
+
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
|
98 |
+
|
99 |
+
# time
|
100 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
101 |
+
timestep_input_dim = block_out_channels[0]
|
102 |
+
|
103 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
104 |
+
|
105 |
+
# class embedding
|
106 |
+
if class_embed_type is None and num_class_embeds is not None:
|
107 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
108 |
+
elif class_embed_type == "timestep":
|
109 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
110 |
+
elif class_embed_type == "identity":
|
111 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
112 |
+
else:
|
113 |
+
self.class_embedding = None
|
114 |
+
|
115 |
+
self.down_blocks = nn.ModuleList([])
|
116 |
+
self.mid_block = None
|
117 |
+
self.up_blocks = nn.ModuleList([])
|
118 |
+
|
119 |
+
if isinstance(only_cross_attention, bool):
|
120 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
121 |
+
|
122 |
+
if isinstance(attention_head_dim, int):
|
123 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
124 |
+
|
125 |
+
# down
|
126 |
+
output_channel = block_out_channels[0]
|
127 |
+
for i, down_block_type in enumerate(down_block_types):
|
128 |
+
res = 2**i
|
129 |
+
input_channel = output_channel
|
130 |
+
output_channel = block_out_channels[i]
|
131 |
+
is_final_block = i == len(block_out_channels) - 1
|
132 |
+
|
133 |
+
down_block = get_down_block(
|
134 |
+
down_block_type,
|
135 |
+
num_layers=layers_per_block,
|
136 |
+
in_channels=input_channel,
|
137 |
+
out_channels=output_channel,
|
138 |
+
temb_channels=time_embed_dim,
|
139 |
+
add_downsample=not is_final_block,
|
140 |
+
resnet_eps=norm_eps,
|
141 |
+
resnet_act_fn=act_fn,
|
142 |
+
resnet_groups=norm_num_groups,
|
143 |
+
cross_attention_dim=cross_attention_dim,
|
144 |
+
attn_num_head_channels=attention_head_dim[i],
|
145 |
+
downsample_padding=downsample_padding,
|
146 |
+
dual_cross_attention=dual_cross_attention,
|
147 |
+
use_linear_projection=use_linear_projection,
|
148 |
+
only_cross_attention=only_cross_attention[i],
|
149 |
+
upcast_attention=upcast_attention,
|
150 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
151 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
152 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
153 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
154 |
+
use_motion_module=use_motion_module
|
155 |
+
and (res in motion_module_resolutions)
|
156 |
+
and (not motion_module_decoder_only),
|
157 |
+
motion_module_type=motion_module_type,
|
158 |
+
motion_module_kwargs=motion_module_kwargs,
|
159 |
+
add_audio_layer=add_audio_layer,
|
160 |
+
audio_condition_method=audio_condition_method,
|
161 |
+
custom_audio_layer=custom_audio_layer,
|
162 |
+
)
|
163 |
+
self.down_blocks.append(down_block)
|
164 |
+
|
165 |
+
# mid
|
166 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
167 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
168 |
+
in_channels=block_out_channels[-1],
|
169 |
+
temb_channels=time_embed_dim,
|
170 |
+
resnet_eps=norm_eps,
|
171 |
+
resnet_act_fn=act_fn,
|
172 |
+
output_scale_factor=mid_block_scale_factor,
|
173 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
174 |
+
cross_attention_dim=cross_attention_dim,
|
175 |
+
attn_num_head_channels=attention_head_dim[-1],
|
176 |
+
resnet_groups=norm_num_groups,
|
177 |
+
dual_cross_attention=dual_cross_attention,
|
178 |
+
use_linear_projection=use_linear_projection,
|
179 |
+
upcast_attention=upcast_attention,
|
180 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
181 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
182 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
183 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
184 |
+
motion_module_type=motion_module_type,
|
185 |
+
motion_module_kwargs=motion_module_kwargs,
|
186 |
+
add_audio_layer=add_audio_layer,
|
187 |
+
audio_condition_method=audio_condition_method,
|
188 |
+
custom_audio_layer=custom_audio_layer,
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
192 |
+
|
193 |
+
# count how many layers upsample the videos
|
194 |
+
self.num_upsamplers = 0
|
195 |
+
|
196 |
+
# up
|
197 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
198 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
199 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
200 |
+
output_channel = reversed_block_out_channels[0]
|
201 |
+
for i, up_block_type in enumerate(up_block_types):
|
202 |
+
res = 2 ** (3 - i)
|
203 |
+
is_final_block = i == len(block_out_channels) - 1
|
204 |
+
|
205 |
+
prev_output_channel = output_channel
|
206 |
+
output_channel = reversed_block_out_channels[i]
|
207 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
208 |
+
|
209 |
+
# add upsample block for all BUT final layer
|
210 |
+
if not is_final_block:
|
211 |
+
add_upsample = True
|
212 |
+
self.num_upsamplers += 1
|
213 |
+
else:
|
214 |
+
add_upsample = False
|
215 |
+
|
216 |
+
up_block = get_up_block(
|
217 |
+
up_block_type,
|
218 |
+
num_layers=layers_per_block + 1,
|
219 |
+
in_channels=input_channel,
|
220 |
+
out_channels=output_channel,
|
221 |
+
prev_output_channel=prev_output_channel,
|
222 |
+
temb_channels=time_embed_dim,
|
223 |
+
add_upsample=add_upsample,
|
224 |
+
resnet_eps=norm_eps,
|
225 |
+
resnet_act_fn=act_fn,
|
226 |
+
resnet_groups=norm_num_groups,
|
227 |
+
cross_attention_dim=cross_attention_dim,
|
228 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
229 |
+
dual_cross_attention=dual_cross_attention,
|
230 |
+
use_linear_projection=use_linear_projection,
|
231 |
+
only_cross_attention=only_cross_attention[i],
|
232 |
+
upcast_attention=upcast_attention,
|
233 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
234 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
235 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
236 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
237 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
238 |
+
motion_module_type=motion_module_type,
|
239 |
+
motion_module_kwargs=motion_module_kwargs,
|
240 |
+
add_audio_layer=add_audio_layer,
|
241 |
+
audio_condition_method=audio_condition_method,
|
242 |
+
custom_audio_layer=custom_audio_layer,
|
243 |
+
)
|
244 |
+
self.up_blocks.append(up_block)
|
245 |
+
prev_output_channel = output_channel
|
246 |
+
|
247 |
+
# out
|
248 |
+
if use_inflated_groupnorm:
|
249 |
+
self.conv_norm_out = InflatedGroupNorm(
|
250 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
self.conv_norm_out = nn.GroupNorm(
|
254 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
255 |
+
)
|
256 |
+
self.conv_act = nn.SiLU()
|
257 |
+
|
258 |
+
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
|
259 |
+
|
260 |
+
def set_attention_slice(self, slice_size):
|
261 |
+
r"""
|
262 |
+
Enable sliced attention computation.
|
263 |
+
|
264 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
265 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
269 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
270 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
271 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
272 |
+
must be a multiple of `slice_size`.
|
273 |
+
"""
|
274 |
+
sliceable_head_dims = []
|
275 |
+
|
276 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
277 |
+
if hasattr(module, "set_attention_slice"):
|
278 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
279 |
+
|
280 |
+
for child in module.children():
|
281 |
+
fn_recursive_retrieve_slicable_dims(child)
|
282 |
+
|
283 |
+
# retrieve number of attention layers
|
284 |
+
for module in self.children():
|
285 |
+
fn_recursive_retrieve_slicable_dims(module)
|
286 |
+
|
287 |
+
num_slicable_layers = len(sliceable_head_dims)
|
288 |
+
|
289 |
+
if slice_size == "auto":
|
290 |
+
# half the attention head size is usually a good trade-off between
|
291 |
+
# speed and memory
|
292 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
293 |
+
elif slice_size == "max":
|
294 |
+
# make smallest slice possible
|
295 |
+
slice_size = num_slicable_layers * [1]
|
296 |
+
|
297 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
298 |
+
|
299 |
+
if len(slice_size) != len(sliceable_head_dims):
|
300 |
+
raise ValueError(
|
301 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
302 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
303 |
+
)
|
304 |
+
|
305 |
+
for i in range(len(slice_size)):
|
306 |
+
size = slice_size[i]
|
307 |
+
dim = sliceable_head_dims[i]
|
308 |
+
if size is not None and size > dim:
|
309 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
310 |
+
|
311 |
+
# Recursively walk through all the children.
|
312 |
+
# Any children which exposes the set_attention_slice method
|
313 |
+
# gets the message
|
314 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
315 |
+
if hasattr(module, "set_attention_slice"):
|
316 |
+
module.set_attention_slice(slice_size.pop())
|
317 |
+
|
318 |
+
for child in module.children():
|
319 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
320 |
+
|
321 |
+
reversed_slice_size = list(reversed(slice_size))
|
322 |
+
for module in self.children():
|
323 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
324 |
+
|
325 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
326 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
327 |
+
module.gradient_checkpointing = value
|
328 |
+
|
329 |
+
def forward(
|
330 |
+
self,
|
331 |
+
sample: torch.FloatTensor,
|
332 |
+
timestep: Union[torch.Tensor, float, int],
|
333 |
+
encoder_hidden_states: torch.Tensor,
|
334 |
+
class_labels: Optional[torch.Tensor] = None,
|
335 |
+
attention_mask: Optional[torch.Tensor] = None,
|
336 |
+
# support controlnet
|
337 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
338 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
339 |
+
return_dict: bool = True,
|
340 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
341 |
+
r"""
|
342 |
+
Args:
|
343 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
344 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
345 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
346 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
347 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
351 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
352 |
+
returning a tuple, the first element is the sample tensor.
|
353 |
+
"""
|
354 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
355 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
356 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
357 |
+
# on the fly if necessary.
|
358 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
359 |
+
|
360 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
361 |
+
forward_upsample_size = False
|
362 |
+
upsample_size = None
|
363 |
+
|
364 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
365 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
366 |
+
forward_upsample_size = True
|
367 |
+
|
368 |
+
# prepare attention_mask
|
369 |
+
if attention_mask is not None:
|
370 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
371 |
+
attention_mask = attention_mask.unsqueeze(1)
|
372 |
+
|
373 |
+
# center input if necessary
|
374 |
+
if self.config.center_input_sample:
|
375 |
+
sample = 2 * sample - 1.0
|
376 |
+
|
377 |
+
# time
|
378 |
+
timesteps = timestep
|
379 |
+
if not torch.is_tensor(timesteps):
|
380 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
381 |
+
is_mps = sample.device.type == "mps"
|
382 |
+
if isinstance(timestep, float):
|
383 |
+
dtype = torch.float32 if is_mps else torch.float64
|
384 |
+
else:
|
385 |
+
dtype = torch.int32 if is_mps else torch.int64
|
386 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
387 |
+
elif len(timesteps.shape) == 0:
|
388 |
+
timesteps = timesteps[None].to(sample.device)
|
389 |
+
|
390 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
391 |
+
timesteps = timesteps.expand(sample.shape[0])
|
392 |
+
|
393 |
+
t_emb = self.time_proj(timesteps)
|
394 |
+
|
395 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
396 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
397 |
+
# there might be better ways to encapsulate this.
|
398 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
399 |
+
emb = self.time_embedding(t_emb)
|
400 |
+
|
401 |
+
if self.class_embedding is not None:
|
402 |
+
if class_labels is None:
|
403 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
404 |
+
|
405 |
+
if self.config.class_embed_type == "timestep":
|
406 |
+
class_labels = self.time_proj(class_labels)
|
407 |
+
|
408 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
409 |
+
emb = emb + class_emb
|
410 |
+
|
411 |
+
# pre-process
|
412 |
+
sample = self.conv_in(sample)
|
413 |
+
|
414 |
+
# down
|
415 |
+
down_block_res_samples = (sample,)
|
416 |
+
for downsample_block in self.down_blocks:
|
417 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
418 |
+
sample, res_samples = downsample_block(
|
419 |
+
hidden_states=sample,
|
420 |
+
temb=emb,
|
421 |
+
encoder_hidden_states=encoder_hidden_states,
|
422 |
+
attention_mask=attention_mask,
|
423 |
+
)
|
424 |
+
else:
|
425 |
+
sample, res_samples = downsample_block(
|
426 |
+
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
427 |
+
)
|
428 |
+
|
429 |
+
down_block_res_samples += res_samples
|
430 |
+
|
431 |
+
# support controlnet
|
432 |
+
down_block_res_samples = list(down_block_res_samples)
|
433 |
+
if down_block_additional_residuals is not None:
|
434 |
+
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
435 |
+
if down_block_additional_residual.dim() == 4: # boardcast
|
436 |
+
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
437 |
+
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
438 |
+
|
439 |
+
# mid
|
440 |
+
sample = self.mid_block(
|
441 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
442 |
+
)
|
443 |
+
|
444 |
+
# support controlnet
|
445 |
+
if mid_block_additional_residual is not None:
|
446 |
+
if mid_block_additional_residual.dim() == 4: # boardcast
|
447 |
+
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
448 |
+
sample = sample + mid_block_additional_residual
|
449 |
+
|
450 |
+
# up
|
451 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
452 |
+
is_final_block = i == len(self.up_blocks) - 1
|
453 |
+
|
454 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
455 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
456 |
+
|
457 |
+
# if we have not reached the final block and need to forward the
|
458 |
+
# upsample size, we do it here
|
459 |
+
if not is_final_block and forward_upsample_size:
|
460 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
461 |
+
|
462 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
463 |
+
sample = upsample_block(
|
464 |
+
hidden_states=sample,
|
465 |
+
temb=emb,
|
466 |
+
res_hidden_states_tuple=res_samples,
|
467 |
+
encoder_hidden_states=encoder_hidden_states,
|
468 |
+
upsample_size=upsample_size,
|
469 |
+
attention_mask=attention_mask,
|
470 |
+
)
|
471 |
+
else:
|
472 |
+
sample = upsample_block(
|
473 |
+
hidden_states=sample,
|
474 |
+
temb=emb,
|
475 |
+
res_hidden_states_tuple=res_samples,
|
476 |
+
upsample_size=upsample_size,
|
477 |
+
encoder_hidden_states=encoder_hidden_states,
|
478 |
+
)
|
479 |
+
|
480 |
+
# post-process
|
481 |
+
sample = self.conv_norm_out(sample)
|
482 |
+
sample = self.conv_act(sample)
|
483 |
+
sample = self.conv_out(sample)
|
484 |
+
|
485 |
+
if not return_dict:
|
486 |
+
return (sample,)
|
487 |
+
|
488 |
+
return UNet3DConditionOutput(sample=sample)
|
489 |
+
|
490 |
+
def load_state_dict(self, state_dict, strict=True):
|
491 |
+
# If the loaded checkpoint's in_channels or out_channels are different from config
|
492 |
+
temp_state_dict = copy.deepcopy(state_dict)
|
493 |
+
if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
|
494 |
+
del temp_state_dict["conv_in.weight"]
|
495 |
+
del temp_state_dict["conv_in.bias"]
|
496 |
+
if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
|
497 |
+
del temp_state_dict["conv_out.weight"]
|
498 |
+
del temp_state_dict["conv_out.bias"]
|
499 |
+
|
500 |
+
# If the loaded checkpoint's cross_attention_dim is different from config
|
501 |
+
keys_to_remove = []
|
502 |
+
for key in temp_state_dict:
|
503 |
+
if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
|
504 |
+
if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
|
505 |
+
keys_to_remove.append(key)
|
506 |
+
|
507 |
+
for key in keys_to_remove:
|
508 |
+
del temp_state_dict[key]
|
509 |
+
|
510 |
+
return super().load_state_dict(state_dict=temp_state_dict, strict=strict)
|
511 |
+
|
512 |
+
@classmethod
|
513 |
+
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
|
514 |
+
unet = cls.from_config(model_config).to(device)
|
515 |
+
if ckpt_path != "":
|
516 |
+
zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
|
517 |
+
ckpt = torch.load(ckpt_path, map_location=device)
|
518 |
+
if "global_step" in ckpt:
|
519 |
+
zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
|
520 |
+
resume_global_step = ckpt["global_step"]
|
521 |
+
else:
|
522 |
+
resume_global_step = 0
|
523 |
+
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
|
524 |
+
unet.load_state_dict(state_dict, strict=False)
|
525 |
+
else:
|
526 |
+
resume_global_step = 0
|
527 |
+
|
528 |
+
return unet, resume_global_step
|
latentsync/models/unet_blocks.py
ADDED
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .attention import Transformer3DModel
|
7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
8 |
+
from .motion_module import get_motion_module
|
9 |
+
|
10 |
+
|
11 |
+
def get_down_block(
|
12 |
+
down_block_type,
|
13 |
+
num_layers,
|
14 |
+
in_channels,
|
15 |
+
out_channels,
|
16 |
+
temb_channels,
|
17 |
+
add_downsample,
|
18 |
+
resnet_eps,
|
19 |
+
resnet_act_fn,
|
20 |
+
attn_num_head_channels,
|
21 |
+
resnet_groups=None,
|
22 |
+
cross_attention_dim=None,
|
23 |
+
downsample_padding=None,
|
24 |
+
dual_cross_attention=False,
|
25 |
+
use_linear_projection=False,
|
26 |
+
only_cross_attention=False,
|
27 |
+
upcast_attention=False,
|
28 |
+
resnet_time_scale_shift="default",
|
29 |
+
unet_use_cross_frame_attention=False,
|
30 |
+
unet_use_temporal_attention=False,
|
31 |
+
use_inflated_groupnorm=False,
|
32 |
+
use_motion_module=None,
|
33 |
+
motion_module_type=None,
|
34 |
+
motion_module_kwargs=None,
|
35 |
+
add_audio_layer=False,
|
36 |
+
audio_condition_method="cross_attn",
|
37 |
+
custom_audio_layer=False,
|
38 |
+
):
|
39 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
40 |
+
if down_block_type == "DownBlock3D":
|
41 |
+
return DownBlock3D(
|
42 |
+
num_layers=num_layers,
|
43 |
+
in_channels=in_channels,
|
44 |
+
out_channels=out_channels,
|
45 |
+
temb_channels=temb_channels,
|
46 |
+
add_downsample=add_downsample,
|
47 |
+
resnet_eps=resnet_eps,
|
48 |
+
resnet_act_fn=resnet_act_fn,
|
49 |
+
resnet_groups=resnet_groups,
|
50 |
+
downsample_padding=downsample_padding,
|
51 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
52 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
53 |
+
use_motion_module=use_motion_module,
|
54 |
+
motion_module_type=motion_module_type,
|
55 |
+
motion_module_kwargs=motion_module_kwargs,
|
56 |
+
)
|
57 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
58 |
+
if cross_attention_dim is None:
|
59 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
60 |
+
return CrossAttnDownBlock3D(
|
61 |
+
num_layers=num_layers,
|
62 |
+
in_channels=in_channels,
|
63 |
+
out_channels=out_channels,
|
64 |
+
temb_channels=temb_channels,
|
65 |
+
add_downsample=add_downsample,
|
66 |
+
resnet_eps=resnet_eps,
|
67 |
+
resnet_act_fn=resnet_act_fn,
|
68 |
+
resnet_groups=resnet_groups,
|
69 |
+
downsample_padding=downsample_padding,
|
70 |
+
cross_attention_dim=cross_attention_dim,
|
71 |
+
attn_num_head_channels=attn_num_head_channels,
|
72 |
+
dual_cross_attention=dual_cross_attention,
|
73 |
+
use_linear_projection=use_linear_projection,
|
74 |
+
only_cross_attention=only_cross_attention,
|
75 |
+
upcast_attention=upcast_attention,
|
76 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
77 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
78 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
79 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
80 |
+
use_motion_module=use_motion_module,
|
81 |
+
motion_module_type=motion_module_type,
|
82 |
+
motion_module_kwargs=motion_module_kwargs,
|
83 |
+
add_audio_layer=add_audio_layer,
|
84 |
+
audio_condition_method=audio_condition_method,
|
85 |
+
custom_audio_layer=custom_audio_layer,
|
86 |
+
)
|
87 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
88 |
+
|
89 |
+
|
90 |
+
def get_up_block(
|
91 |
+
up_block_type,
|
92 |
+
num_layers,
|
93 |
+
in_channels,
|
94 |
+
out_channels,
|
95 |
+
prev_output_channel,
|
96 |
+
temb_channels,
|
97 |
+
add_upsample,
|
98 |
+
resnet_eps,
|
99 |
+
resnet_act_fn,
|
100 |
+
attn_num_head_channels,
|
101 |
+
resnet_groups=None,
|
102 |
+
cross_attention_dim=None,
|
103 |
+
dual_cross_attention=False,
|
104 |
+
use_linear_projection=False,
|
105 |
+
only_cross_attention=False,
|
106 |
+
upcast_attention=False,
|
107 |
+
resnet_time_scale_shift="default",
|
108 |
+
unet_use_cross_frame_attention=False,
|
109 |
+
unet_use_temporal_attention=False,
|
110 |
+
use_inflated_groupnorm=False,
|
111 |
+
use_motion_module=None,
|
112 |
+
motion_module_type=None,
|
113 |
+
motion_module_kwargs=None,
|
114 |
+
add_audio_layer=False,
|
115 |
+
audio_condition_method="cross_attn",
|
116 |
+
custom_audio_layer=False,
|
117 |
+
):
|
118 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
119 |
+
if up_block_type == "UpBlock3D":
|
120 |
+
return UpBlock3D(
|
121 |
+
num_layers=num_layers,
|
122 |
+
in_channels=in_channels,
|
123 |
+
out_channels=out_channels,
|
124 |
+
prev_output_channel=prev_output_channel,
|
125 |
+
temb_channels=temb_channels,
|
126 |
+
add_upsample=add_upsample,
|
127 |
+
resnet_eps=resnet_eps,
|
128 |
+
resnet_act_fn=resnet_act_fn,
|
129 |
+
resnet_groups=resnet_groups,
|
130 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
131 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
132 |
+
use_motion_module=use_motion_module,
|
133 |
+
motion_module_type=motion_module_type,
|
134 |
+
motion_module_kwargs=motion_module_kwargs,
|
135 |
+
)
|
136 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
137 |
+
if cross_attention_dim is None:
|
138 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
139 |
+
return CrossAttnUpBlock3D(
|
140 |
+
num_layers=num_layers,
|
141 |
+
in_channels=in_channels,
|
142 |
+
out_channels=out_channels,
|
143 |
+
prev_output_channel=prev_output_channel,
|
144 |
+
temb_channels=temb_channels,
|
145 |
+
add_upsample=add_upsample,
|
146 |
+
resnet_eps=resnet_eps,
|
147 |
+
resnet_act_fn=resnet_act_fn,
|
148 |
+
resnet_groups=resnet_groups,
|
149 |
+
cross_attention_dim=cross_attention_dim,
|
150 |
+
attn_num_head_channels=attn_num_head_channels,
|
151 |
+
dual_cross_attention=dual_cross_attention,
|
152 |
+
use_linear_projection=use_linear_projection,
|
153 |
+
only_cross_attention=only_cross_attention,
|
154 |
+
upcast_attention=upcast_attention,
|
155 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
156 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
157 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
158 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
159 |
+
use_motion_module=use_motion_module,
|
160 |
+
motion_module_type=motion_module_type,
|
161 |
+
motion_module_kwargs=motion_module_kwargs,
|
162 |
+
add_audio_layer=add_audio_layer,
|
163 |
+
audio_condition_method=audio_condition_method,
|
164 |
+
custom_audio_layer=custom_audio_layer,
|
165 |
+
)
|
166 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
167 |
+
|
168 |
+
|
169 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
in_channels: int,
|
173 |
+
temb_channels: int,
|
174 |
+
dropout: float = 0.0,
|
175 |
+
num_layers: int = 1,
|
176 |
+
resnet_eps: float = 1e-6,
|
177 |
+
resnet_time_scale_shift: str = "default",
|
178 |
+
resnet_act_fn: str = "swish",
|
179 |
+
resnet_groups: int = 32,
|
180 |
+
resnet_pre_norm: bool = True,
|
181 |
+
attn_num_head_channels=1,
|
182 |
+
output_scale_factor=1.0,
|
183 |
+
cross_attention_dim=1280,
|
184 |
+
dual_cross_attention=False,
|
185 |
+
use_linear_projection=False,
|
186 |
+
upcast_attention=False,
|
187 |
+
unet_use_cross_frame_attention=False,
|
188 |
+
unet_use_temporal_attention=False,
|
189 |
+
use_inflated_groupnorm=False,
|
190 |
+
use_motion_module=None,
|
191 |
+
motion_module_type=None,
|
192 |
+
motion_module_kwargs=None,
|
193 |
+
add_audio_layer=False,
|
194 |
+
audio_condition_method="cross_attn",
|
195 |
+
custom_audio_layer: bool = False,
|
196 |
+
):
|
197 |
+
super().__init__()
|
198 |
+
|
199 |
+
self.has_cross_attention = True
|
200 |
+
self.attn_num_head_channels = attn_num_head_channels
|
201 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
202 |
+
|
203 |
+
# there is always at least one resnet
|
204 |
+
resnets = [
|
205 |
+
ResnetBlock3D(
|
206 |
+
in_channels=in_channels,
|
207 |
+
out_channels=in_channels,
|
208 |
+
temb_channels=temb_channels,
|
209 |
+
eps=resnet_eps,
|
210 |
+
groups=resnet_groups,
|
211 |
+
dropout=dropout,
|
212 |
+
time_embedding_norm=resnet_time_scale_shift,
|
213 |
+
non_linearity=resnet_act_fn,
|
214 |
+
output_scale_factor=output_scale_factor,
|
215 |
+
pre_norm=resnet_pre_norm,
|
216 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
217 |
+
)
|
218 |
+
]
|
219 |
+
attentions = []
|
220 |
+
audio_attentions = []
|
221 |
+
motion_modules = []
|
222 |
+
|
223 |
+
for _ in range(num_layers):
|
224 |
+
if dual_cross_attention:
|
225 |
+
raise NotImplementedError
|
226 |
+
attentions.append(
|
227 |
+
Transformer3DModel(
|
228 |
+
attn_num_head_channels,
|
229 |
+
in_channels // attn_num_head_channels,
|
230 |
+
in_channels=in_channels,
|
231 |
+
num_layers=1,
|
232 |
+
cross_attention_dim=cross_attention_dim,
|
233 |
+
norm_num_groups=resnet_groups,
|
234 |
+
use_linear_projection=use_linear_projection,
|
235 |
+
upcast_attention=upcast_attention,
|
236 |
+
use_motion_module=use_motion_module,
|
237 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
238 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
239 |
+
add_audio_layer=add_audio_layer,
|
240 |
+
audio_condition_method=audio_condition_method,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
audio_attentions.append(
|
244 |
+
Transformer3DModel(
|
245 |
+
attn_num_head_channels,
|
246 |
+
in_channels // attn_num_head_channels,
|
247 |
+
in_channels=in_channels,
|
248 |
+
num_layers=1,
|
249 |
+
cross_attention_dim=cross_attention_dim,
|
250 |
+
norm_num_groups=resnet_groups,
|
251 |
+
use_linear_projection=use_linear_projection,
|
252 |
+
upcast_attention=upcast_attention,
|
253 |
+
use_motion_module=use_motion_module,
|
254 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
255 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
256 |
+
add_audio_layer=add_audio_layer,
|
257 |
+
audio_condition_method=audio_condition_method,
|
258 |
+
custom_audio_layer=True,
|
259 |
+
)
|
260 |
+
if custom_audio_layer
|
261 |
+
else None
|
262 |
+
)
|
263 |
+
motion_modules.append(
|
264 |
+
get_motion_module(
|
265 |
+
in_channels=in_channels,
|
266 |
+
motion_module_type=motion_module_type,
|
267 |
+
motion_module_kwargs=motion_module_kwargs,
|
268 |
+
)
|
269 |
+
if use_motion_module
|
270 |
+
else None
|
271 |
+
)
|
272 |
+
resnets.append(
|
273 |
+
ResnetBlock3D(
|
274 |
+
in_channels=in_channels,
|
275 |
+
out_channels=in_channels,
|
276 |
+
temb_channels=temb_channels,
|
277 |
+
eps=resnet_eps,
|
278 |
+
groups=resnet_groups,
|
279 |
+
dropout=dropout,
|
280 |
+
time_embedding_norm=resnet_time_scale_shift,
|
281 |
+
non_linearity=resnet_act_fn,
|
282 |
+
output_scale_factor=output_scale_factor,
|
283 |
+
pre_norm=resnet_pre_norm,
|
284 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
285 |
+
)
|
286 |
+
)
|
287 |
+
|
288 |
+
self.attentions = nn.ModuleList(attentions)
|
289 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
290 |
+
self.resnets = nn.ModuleList(resnets)
|
291 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
292 |
+
|
293 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
294 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
295 |
+
for attn, audio_attn, resnet, motion_module in zip(
|
296 |
+
self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
|
297 |
+
):
|
298 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
299 |
+
hidden_states = (
|
300 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
301 |
+
if audio_attn is not None
|
302 |
+
else hidden_states
|
303 |
+
)
|
304 |
+
hidden_states = (
|
305 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
306 |
+
if motion_module is not None
|
307 |
+
else hidden_states
|
308 |
+
)
|
309 |
+
hidden_states = resnet(hidden_states, temb)
|
310 |
+
|
311 |
+
return hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
class CrossAttnDownBlock3D(nn.Module):
|
315 |
+
def __init__(
|
316 |
+
self,
|
317 |
+
in_channels: int,
|
318 |
+
out_channels: int,
|
319 |
+
temb_channels: int,
|
320 |
+
dropout: float = 0.0,
|
321 |
+
num_layers: int = 1,
|
322 |
+
resnet_eps: float = 1e-6,
|
323 |
+
resnet_time_scale_shift: str = "default",
|
324 |
+
resnet_act_fn: str = "swish",
|
325 |
+
resnet_groups: int = 32,
|
326 |
+
resnet_pre_norm: bool = True,
|
327 |
+
attn_num_head_channels=1,
|
328 |
+
cross_attention_dim=1280,
|
329 |
+
output_scale_factor=1.0,
|
330 |
+
downsample_padding=1,
|
331 |
+
add_downsample=True,
|
332 |
+
dual_cross_attention=False,
|
333 |
+
use_linear_projection=False,
|
334 |
+
only_cross_attention=False,
|
335 |
+
upcast_attention=False,
|
336 |
+
unet_use_cross_frame_attention=False,
|
337 |
+
unet_use_temporal_attention=False,
|
338 |
+
use_inflated_groupnorm=False,
|
339 |
+
use_motion_module=None,
|
340 |
+
motion_module_type=None,
|
341 |
+
motion_module_kwargs=None,
|
342 |
+
add_audio_layer=False,
|
343 |
+
audio_condition_method="cross_attn",
|
344 |
+
custom_audio_layer: bool = False,
|
345 |
+
):
|
346 |
+
super().__init__()
|
347 |
+
resnets = []
|
348 |
+
attentions = []
|
349 |
+
audio_attentions = []
|
350 |
+
motion_modules = []
|
351 |
+
|
352 |
+
self.has_cross_attention = True
|
353 |
+
self.attn_num_head_channels = attn_num_head_channels
|
354 |
+
|
355 |
+
for i in range(num_layers):
|
356 |
+
in_channels = in_channels if i == 0 else out_channels
|
357 |
+
resnets.append(
|
358 |
+
ResnetBlock3D(
|
359 |
+
in_channels=in_channels,
|
360 |
+
out_channels=out_channels,
|
361 |
+
temb_channels=temb_channels,
|
362 |
+
eps=resnet_eps,
|
363 |
+
groups=resnet_groups,
|
364 |
+
dropout=dropout,
|
365 |
+
time_embedding_norm=resnet_time_scale_shift,
|
366 |
+
non_linearity=resnet_act_fn,
|
367 |
+
output_scale_factor=output_scale_factor,
|
368 |
+
pre_norm=resnet_pre_norm,
|
369 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
370 |
+
)
|
371 |
+
)
|
372 |
+
if dual_cross_attention:
|
373 |
+
raise NotImplementedError
|
374 |
+
attentions.append(
|
375 |
+
Transformer3DModel(
|
376 |
+
attn_num_head_channels,
|
377 |
+
out_channels // attn_num_head_channels,
|
378 |
+
in_channels=out_channels,
|
379 |
+
num_layers=1,
|
380 |
+
cross_attention_dim=cross_attention_dim,
|
381 |
+
norm_num_groups=resnet_groups,
|
382 |
+
use_linear_projection=use_linear_projection,
|
383 |
+
only_cross_attention=only_cross_attention,
|
384 |
+
upcast_attention=upcast_attention,
|
385 |
+
use_motion_module=use_motion_module,
|
386 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
387 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
388 |
+
add_audio_layer=add_audio_layer,
|
389 |
+
audio_condition_method=audio_condition_method,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
audio_attentions.append(
|
393 |
+
Transformer3DModel(
|
394 |
+
attn_num_head_channels,
|
395 |
+
out_channels // attn_num_head_channels,
|
396 |
+
in_channels=out_channels,
|
397 |
+
num_layers=1,
|
398 |
+
cross_attention_dim=cross_attention_dim,
|
399 |
+
norm_num_groups=resnet_groups,
|
400 |
+
use_linear_projection=use_linear_projection,
|
401 |
+
only_cross_attention=only_cross_attention,
|
402 |
+
upcast_attention=upcast_attention,
|
403 |
+
use_motion_module=use_motion_module,
|
404 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
405 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
406 |
+
add_audio_layer=add_audio_layer,
|
407 |
+
audio_condition_method=audio_condition_method,
|
408 |
+
custom_audio_layer=True,
|
409 |
+
)
|
410 |
+
if custom_audio_layer
|
411 |
+
else None
|
412 |
+
)
|
413 |
+
motion_modules.append(
|
414 |
+
get_motion_module(
|
415 |
+
in_channels=out_channels,
|
416 |
+
motion_module_type=motion_module_type,
|
417 |
+
motion_module_kwargs=motion_module_kwargs,
|
418 |
+
)
|
419 |
+
if use_motion_module
|
420 |
+
else None
|
421 |
+
)
|
422 |
+
|
423 |
+
self.attentions = nn.ModuleList(attentions)
|
424 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
425 |
+
self.resnets = nn.ModuleList(resnets)
|
426 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
427 |
+
|
428 |
+
if add_downsample:
|
429 |
+
self.downsamplers = nn.ModuleList(
|
430 |
+
[
|
431 |
+
Downsample3D(
|
432 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
433 |
+
)
|
434 |
+
]
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
self.downsamplers = None
|
438 |
+
|
439 |
+
self.gradient_checkpointing = False
|
440 |
+
|
441 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
442 |
+
output_states = ()
|
443 |
+
|
444 |
+
for resnet, attn, audio_attn, motion_module in zip(
|
445 |
+
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
|
446 |
+
):
|
447 |
+
if self.training and self.gradient_checkpointing:
|
448 |
+
|
449 |
+
def create_custom_forward(module, return_dict=None):
|
450 |
+
def custom_forward(*inputs):
|
451 |
+
if return_dict is not None:
|
452 |
+
return module(*inputs, return_dict=return_dict)
|
453 |
+
else:
|
454 |
+
return module(*inputs)
|
455 |
+
|
456 |
+
return custom_forward
|
457 |
+
|
458 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
459 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
460 |
+
create_custom_forward(attn, return_dict=False),
|
461 |
+
hidden_states,
|
462 |
+
encoder_hidden_states,
|
463 |
+
)[0]
|
464 |
+
if motion_module is not None:
|
465 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
466 |
+
create_custom_forward(motion_module),
|
467 |
+
hidden_states.requires_grad_(),
|
468 |
+
temb,
|
469 |
+
encoder_hidden_states,
|
470 |
+
)
|
471 |
+
|
472 |
+
else:
|
473 |
+
hidden_states = resnet(hidden_states, temb)
|
474 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
475 |
+
|
476 |
+
hidden_states = (
|
477 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
478 |
+
if audio_attn is not None
|
479 |
+
else hidden_states
|
480 |
+
)
|
481 |
+
|
482 |
+
# add motion module
|
483 |
+
hidden_states = (
|
484 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
485 |
+
if motion_module is not None
|
486 |
+
else hidden_states
|
487 |
+
)
|
488 |
+
|
489 |
+
output_states += (hidden_states,)
|
490 |
+
|
491 |
+
if self.downsamplers is not None:
|
492 |
+
for downsampler in self.downsamplers:
|
493 |
+
hidden_states = downsampler(hidden_states)
|
494 |
+
|
495 |
+
output_states += (hidden_states,)
|
496 |
+
|
497 |
+
return hidden_states, output_states
|
498 |
+
|
499 |
+
|
500 |
+
class DownBlock3D(nn.Module):
|
501 |
+
def __init__(
|
502 |
+
self,
|
503 |
+
in_channels: int,
|
504 |
+
out_channels: int,
|
505 |
+
temb_channels: int,
|
506 |
+
dropout: float = 0.0,
|
507 |
+
num_layers: int = 1,
|
508 |
+
resnet_eps: float = 1e-6,
|
509 |
+
resnet_time_scale_shift: str = "default",
|
510 |
+
resnet_act_fn: str = "swish",
|
511 |
+
resnet_groups: int = 32,
|
512 |
+
resnet_pre_norm: bool = True,
|
513 |
+
output_scale_factor=1.0,
|
514 |
+
add_downsample=True,
|
515 |
+
downsample_padding=1,
|
516 |
+
use_inflated_groupnorm=False,
|
517 |
+
use_motion_module=None,
|
518 |
+
motion_module_type=None,
|
519 |
+
motion_module_kwargs=None,
|
520 |
+
):
|
521 |
+
super().__init__()
|
522 |
+
resnets = []
|
523 |
+
motion_modules = []
|
524 |
+
|
525 |
+
for i in range(num_layers):
|
526 |
+
in_channels = in_channels if i == 0 else out_channels
|
527 |
+
resnets.append(
|
528 |
+
ResnetBlock3D(
|
529 |
+
in_channels=in_channels,
|
530 |
+
out_channels=out_channels,
|
531 |
+
temb_channels=temb_channels,
|
532 |
+
eps=resnet_eps,
|
533 |
+
groups=resnet_groups,
|
534 |
+
dropout=dropout,
|
535 |
+
time_embedding_norm=resnet_time_scale_shift,
|
536 |
+
non_linearity=resnet_act_fn,
|
537 |
+
output_scale_factor=output_scale_factor,
|
538 |
+
pre_norm=resnet_pre_norm,
|
539 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
540 |
+
)
|
541 |
+
)
|
542 |
+
motion_modules.append(
|
543 |
+
get_motion_module(
|
544 |
+
in_channels=out_channels,
|
545 |
+
motion_module_type=motion_module_type,
|
546 |
+
motion_module_kwargs=motion_module_kwargs,
|
547 |
+
)
|
548 |
+
if use_motion_module
|
549 |
+
else None
|
550 |
+
)
|
551 |
+
|
552 |
+
self.resnets = nn.ModuleList(resnets)
|
553 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
554 |
+
|
555 |
+
if add_downsample:
|
556 |
+
self.downsamplers = nn.ModuleList(
|
557 |
+
[
|
558 |
+
Downsample3D(
|
559 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
560 |
+
)
|
561 |
+
]
|
562 |
+
)
|
563 |
+
else:
|
564 |
+
self.downsamplers = None
|
565 |
+
|
566 |
+
self.gradient_checkpointing = False
|
567 |
+
|
568 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
569 |
+
output_states = ()
|
570 |
+
|
571 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
572 |
+
if self.training and self.gradient_checkpointing:
|
573 |
+
|
574 |
+
def create_custom_forward(module):
|
575 |
+
def custom_forward(*inputs):
|
576 |
+
return module(*inputs)
|
577 |
+
|
578 |
+
return custom_forward
|
579 |
+
|
580 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
581 |
+
if motion_module is not None:
|
582 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
583 |
+
create_custom_forward(motion_module),
|
584 |
+
hidden_states.requires_grad_(),
|
585 |
+
temb,
|
586 |
+
encoder_hidden_states,
|
587 |
+
)
|
588 |
+
else:
|
589 |
+
hidden_states = resnet(hidden_states, temb)
|
590 |
+
|
591 |
+
# add motion module
|
592 |
+
hidden_states = (
|
593 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
594 |
+
if motion_module is not None
|
595 |
+
else hidden_states
|
596 |
+
)
|
597 |
+
|
598 |
+
output_states += (hidden_states,)
|
599 |
+
|
600 |
+
if self.downsamplers is not None:
|
601 |
+
for downsampler in self.downsamplers:
|
602 |
+
hidden_states = downsampler(hidden_states)
|
603 |
+
|
604 |
+
output_states += (hidden_states,)
|
605 |
+
|
606 |
+
return hidden_states, output_states
|
607 |
+
|
608 |
+
|
609 |
+
class CrossAttnUpBlock3D(nn.Module):
|
610 |
+
def __init__(
|
611 |
+
self,
|
612 |
+
in_channels: int,
|
613 |
+
out_channels: int,
|
614 |
+
prev_output_channel: int,
|
615 |
+
temb_channels: int,
|
616 |
+
dropout: float = 0.0,
|
617 |
+
num_layers: int = 1,
|
618 |
+
resnet_eps: float = 1e-6,
|
619 |
+
resnet_time_scale_shift: str = "default",
|
620 |
+
resnet_act_fn: str = "swish",
|
621 |
+
resnet_groups: int = 32,
|
622 |
+
resnet_pre_norm: bool = True,
|
623 |
+
attn_num_head_channels=1,
|
624 |
+
cross_attention_dim=1280,
|
625 |
+
output_scale_factor=1.0,
|
626 |
+
add_upsample=True,
|
627 |
+
dual_cross_attention=False,
|
628 |
+
use_linear_projection=False,
|
629 |
+
only_cross_attention=False,
|
630 |
+
upcast_attention=False,
|
631 |
+
unet_use_cross_frame_attention=False,
|
632 |
+
unet_use_temporal_attention=False,
|
633 |
+
use_inflated_groupnorm=False,
|
634 |
+
use_motion_module=None,
|
635 |
+
motion_module_type=None,
|
636 |
+
motion_module_kwargs=None,
|
637 |
+
add_audio_layer=False,
|
638 |
+
audio_condition_method="cross_attn",
|
639 |
+
custom_audio_layer=False,
|
640 |
+
):
|
641 |
+
super().__init__()
|
642 |
+
resnets = []
|
643 |
+
attentions = []
|
644 |
+
audio_attentions = []
|
645 |
+
motion_modules = []
|
646 |
+
|
647 |
+
self.has_cross_attention = True
|
648 |
+
self.attn_num_head_channels = attn_num_head_channels
|
649 |
+
|
650 |
+
for i in range(num_layers):
|
651 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
652 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
653 |
+
|
654 |
+
resnets.append(
|
655 |
+
ResnetBlock3D(
|
656 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
657 |
+
out_channels=out_channels,
|
658 |
+
temb_channels=temb_channels,
|
659 |
+
eps=resnet_eps,
|
660 |
+
groups=resnet_groups,
|
661 |
+
dropout=dropout,
|
662 |
+
time_embedding_norm=resnet_time_scale_shift,
|
663 |
+
non_linearity=resnet_act_fn,
|
664 |
+
output_scale_factor=output_scale_factor,
|
665 |
+
pre_norm=resnet_pre_norm,
|
666 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
667 |
+
)
|
668 |
+
)
|
669 |
+
if dual_cross_attention:
|
670 |
+
raise NotImplementedError
|
671 |
+
attentions.append(
|
672 |
+
Transformer3DModel(
|
673 |
+
attn_num_head_channels,
|
674 |
+
out_channels // attn_num_head_channels,
|
675 |
+
in_channels=out_channels,
|
676 |
+
num_layers=1,
|
677 |
+
cross_attention_dim=cross_attention_dim,
|
678 |
+
norm_num_groups=resnet_groups,
|
679 |
+
use_linear_projection=use_linear_projection,
|
680 |
+
only_cross_attention=only_cross_attention,
|
681 |
+
upcast_attention=upcast_attention,
|
682 |
+
use_motion_module=use_motion_module,
|
683 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
684 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
685 |
+
add_audio_layer=add_audio_layer,
|
686 |
+
audio_condition_method=audio_condition_method,
|
687 |
+
)
|
688 |
+
)
|
689 |
+
audio_attentions.append(
|
690 |
+
Transformer3DModel(
|
691 |
+
attn_num_head_channels,
|
692 |
+
out_channels // attn_num_head_channels,
|
693 |
+
in_channels=out_channels,
|
694 |
+
num_layers=1,
|
695 |
+
cross_attention_dim=cross_attention_dim,
|
696 |
+
norm_num_groups=resnet_groups,
|
697 |
+
use_linear_projection=use_linear_projection,
|
698 |
+
only_cross_attention=only_cross_attention,
|
699 |
+
upcast_attention=upcast_attention,
|
700 |
+
use_motion_module=use_motion_module,
|
701 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
702 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
703 |
+
add_audio_layer=add_audio_layer,
|
704 |
+
audio_condition_method=audio_condition_method,
|
705 |
+
custom_audio_layer=True,
|
706 |
+
)
|
707 |
+
if custom_audio_layer
|
708 |
+
else None
|
709 |
+
)
|
710 |
+
motion_modules.append(
|
711 |
+
get_motion_module(
|
712 |
+
in_channels=out_channels,
|
713 |
+
motion_module_type=motion_module_type,
|
714 |
+
motion_module_kwargs=motion_module_kwargs,
|
715 |
+
)
|
716 |
+
if use_motion_module
|
717 |
+
else None
|
718 |
+
)
|
719 |
+
|
720 |
+
self.attentions = nn.ModuleList(attentions)
|
721 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
722 |
+
self.resnets = nn.ModuleList(resnets)
|
723 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
724 |
+
|
725 |
+
if add_upsample:
|
726 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
727 |
+
else:
|
728 |
+
self.upsamplers = None
|
729 |
+
|
730 |
+
self.gradient_checkpointing = False
|
731 |
+
|
732 |
+
def forward(
|
733 |
+
self,
|
734 |
+
hidden_states,
|
735 |
+
res_hidden_states_tuple,
|
736 |
+
temb=None,
|
737 |
+
encoder_hidden_states=None,
|
738 |
+
upsample_size=None,
|
739 |
+
attention_mask=None,
|
740 |
+
):
|
741 |
+
for resnet, attn, audio_attn, motion_module in zip(
|
742 |
+
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
|
743 |
+
):
|
744 |
+
# pop res hidden states
|
745 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
746 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
747 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
748 |
+
|
749 |
+
if self.training and self.gradient_checkpointing:
|
750 |
+
|
751 |
+
def create_custom_forward(module, return_dict=None):
|
752 |
+
def custom_forward(*inputs):
|
753 |
+
if return_dict is not None:
|
754 |
+
return module(*inputs, return_dict=return_dict)
|
755 |
+
else:
|
756 |
+
return module(*inputs)
|
757 |
+
|
758 |
+
return custom_forward
|
759 |
+
|
760 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
761 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
762 |
+
create_custom_forward(attn, return_dict=False),
|
763 |
+
hidden_states,
|
764 |
+
encoder_hidden_states,
|
765 |
+
)[0]
|
766 |
+
if motion_module is not None:
|
767 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
768 |
+
create_custom_forward(motion_module),
|
769 |
+
hidden_states.requires_grad_(),
|
770 |
+
temb,
|
771 |
+
encoder_hidden_states,
|
772 |
+
)
|
773 |
+
|
774 |
+
else:
|
775 |
+
hidden_states = resnet(hidden_states, temb)
|
776 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
777 |
+
hidden_states = (
|
778 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
779 |
+
if audio_attn is not None
|
780 |
+
else hidden_states
|
781 |
+
)
|
782 |
+
|
783 |
+
# add motion module
|
784 |
+
hidden_states = (
|
785 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
786 |
+
if motion_module is not None
|
787 |
+
else hidden_states
|
788 |
+
)
|
789 |
+
|
790 |
+
if self.upsamplers is not None:
|
791 |
+
for upsampler in self.upsamplers:
|
792 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
793 |
+
|
794 |
+
return hidden_states
|
795 |
+
|
796 |
+
|
797 |
+
class UpBlock3D(nn.Module):
|
798 |
+
def __init__(
|
799 |
+
self,
|
800 |
+
in_channels: int,
|
801 |
+
prev_output_channel: int,
|
802 |
+
out_channels: int,
|
803 |
+
temb_channels: int,
|
804 |
+
dropout: float = 0.0,
|
805 |
+
num_layers: int = 1,
|
806 |
+
resnet_eps: float = 1e-6,
|
807 |
+
resnet_time_scale_shift: str = "default",
|
808 |
+
resnet_act_fn: str = "swish",
|
809 |
+
resnet_groups: int = 32,
|
810 |
+
resnet_pre_norm: bool = True,
|
811 |
+
output_scale_factor=1.0,
|
812 |
+
add_upsample=True,
|
813 |
+
use_inflated_groupnorm=False,
|
814 |
+
use_motion_module=None,
|
815 |
+
motion_module_type=None,
|
816 |
+
motion_module_kwargs=None,
|
817 |
+
):
|
818 |
+
super().__init__()
|
819 |
+
resnets = []
|
820 |
+
motion_modules = []
|
821 |
+
|
822 |
+
for i in range(num_layers):
|
823 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
824 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
825 |
+
|
826 |
+
resnets.append(
|
827 |
+
ResnetBlock3D(
|
828 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
829 |
+
out_channels=out_channels,
|
830 |
+
temb_channels=temb_channels,
|
831 |
+
eps=resnet_eps,
|
832 |
+
groups=resnet_groups,
|
833 |
+
dropout=dropout,
|
834 |
+
time_embedding_norm=resnet_time_scale_shift,
|
835 |
+
non_linearity=resnet_act_fn,
|
836 |
+
output_scale_factor=output_scale_factor,
|
837 |
+
pre_norm=resnet_pre_norm,
|
838 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
839 |
+
)
|
840 |
+
)
|
841 |
+
motion_modules.append(
|
842 |
+
get_motion_module(
|
843 |
+
in_channels=out_channels,
|
844 |
+
motion_module_type=motion_module_type,
|
845 |
+
motion_module_kwargs=motion_module_kwargs,
|
846 |
+
)
|
847 |
+
if use_motion_module
|
848 |
+
else None
|
849 |
+
)
|
850 |
+
|
851 |
+
self.resnets = nn.ModuleList(resnets)
|
852 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
853 |
+
|
854 |
+
if add_upsample:
|
855 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
856 |
+
else:
|
857 |
+
self.upsamplers = None
|
858 |
+
|
859 |
+
self.gradient_checkpointing = False
|
860 |
+
|
861 |
+
def forward(
|
862 |
+
self,
|
863 |
+
hidden_states,
|
864 |
+
res_hidden_states_tuple,
|
865 |
+
temb=None,
|
866 |
+
upsample_size=None,
|
867 |
+
encoder_hidden_states=None,
|
868 |
+
):
|
869 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
870 |
+
# pop res hidden states
|
871 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
872 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
873 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
874 |
+
|
875 |
+
if self.training and self.gradient_checkpointing:
|
876 |
+
|
877 |
+
def create_custom_forward(module):
|
878 |
+
def custom_forward(*inputs):
|
879 |
+
return module(*inputs)
|
880 |
+
|
881 |
+
return custom_forward
|
882 |
+
|
883 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
884 |
+
if motion_module is not None:
|
885 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
886 |
+
create_custom_forward(motion_module),
|
887 |
+
hidden_states.requires_grad_(),
|
888 |
+
temb,
|
889 |
+
encoder_hidden_states,
|
890 |
+
)
|
891 |
+
else:
|
892 |
+
hidden_states = resnet(hidden_states, temb)
|
893 |
+
hidden_states = (
|
894 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
895 |
+
if motion_module is not None
|
896 |
+
else hidden_states
|
897 |
+
)
|
898 |
+
|
899 |
+
if self.upsamplers is not None:
|
900 |
+
for upsampler in self.upsamplers:
|
901 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
902 |
+
|
903 |
+
return hidden_states
|
latentsync/models/utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
def zero_module(module):
|
16 |
+
# Zero out the parameters of a module and return it.
|
17 |
+
for p in module.parameters():
|
18 |
+
p.detach().zero_()
|
19 |
+
return module
|
latentsync/pipelines/lipsync_pipeline.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
from typing import Callable, List, Optional, Union
|
7 |
+
import subprocess
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torchvision
|
12 |
+
|
13 |
+
from diffusers.utils import is_accelerate_available
|
14 |
+
from packaging import version
|
15 |
+
|
16 |
+
from diffusers.configuration_utils import FrozenDict
|
17 |
+
from diffusers.models import AutoencoderKL
|
18 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
19 |
+
from diffusers.schedulers import (
|
20 |
+
DDIMScheduler,
|
21 |
+
DPMSolverMultistepScheduler,
|
22 |
+
EulerAncestralDiscreteScheduler,
|
23 |
+
EulerDiscreteScheduler,
|
24 |
+
LMSDiscreteScheduler,
|
25 |
+
PNDMScheduler,
|
26 |
+
)
|
27 |
+
from diffusers.utils import deprecate, logging
|
28 |
+
|
29 |
+
from einops import rearrange
|
30 |
+
|
31 |
+
from ..models.unet import UNet3DConditionModel
|
32 |
+
from ..utils.image_processor import ImageProcessor
|
33 |
+
from ..utils.util import read_video, read_audio, write_video
|
34 |
+
from ..whisper.audio2feature import Audio2Feature
|
35 |
+
import tqdm
|
36 |
+
import soundfile as sf
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
class LipsyncPipeline(DiffusionPipeline):
|
42 |
+
_optional_components = []
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
vae: AutoencoderKL,
|
47 |
+
audio_encoder: Audio2Feature,
|
48 |
+
unet: UNet3DConditionModel,
|
49 |
+
scheduler: Union[
|
50 |
+
DDIMScheduler,
|
51 |
+
PNDMScheduler,
|
52 |
+
LMSDiscreteScheduler,
|
53 |
+
EulerDiscreteScheduler,
|
54 |
+
EulerAncestralDiscreteScheduler,
|
55 |
+
DPMSolverMultistepScheduler,
|
56 |
+
],
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
61 |
+
deprecation_message = (
|
62 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
63 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
64 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
65 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
66 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
67 |
+
" file"
|
68 |
+
)
|
69 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
70 |
+
new_config = dict(scheduler.config)
|
71 |
+
new_config["steps_offset"] = 1
|
72 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
73 |
+
|
74 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
75 |
+
deprecation_message = (
|
76 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
77 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
78 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
79 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
80 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
81 |
+
)
|
82 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
83 |
+
new_config = dict(scheduler.config)
|
84 |
+
new_config["clip_sample"] = False
|
85 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
86 |
+
|
87 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
88 |
+
version.parse(unet.config._diffusers_version).base_version
|
89 |
+
) < version.parse("0.9.0.dev0")
|
90 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
91 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
92 |
+
deprecation_message = (
|
93 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
94 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
95 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
96 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
97 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
98 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
99 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
100 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
101 |
+
" the `unet/config.json` file"
|
102 |
+
)
|
103 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
104 |
+
new_config = dict(unet.config)
|
105 |
+
new_config["sample_size"] = 64
|
106 |
+
unet._internal_dict = FrozenDict(new_config)
|
107 |
+
|
108 |
+
self.register_modules(
|
109 |
+
vae=vae,
|
110 |
+
audio_encoder=audio_encoder,
|
111 |
+
unet=unet,
|
112 |
+
scheduler=scheduler,
|
113 |
+
)
|
114 |
+
|
115 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
116 |
+
|
117 |
+
self.set_progress_bar_config(desc="Steps")
|
118 |
+
|
119 |
+
def enable_vae_slicing(self):
|
120 |
+
self.vae.enable_slicing()
|
121 |
+
|
122 |
+
def disable_vae_slicing(self):
|
123 |
+
self.vae.disable_slicing()
|
124 |
+
|
125 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
126 |
+
if is_accelerate_available():
|
127 |
+
from accelerate import cpu_offload
|
128 |
+
else:
|
129 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
130 |
+
|
131 |
+
device = torch.device(f"cuda:{gpu_id}")
|
132 |
+
|
133 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
134 |
+
if cpu_offloaded_model is not None:
|
135 |
+
cpu_offload(cpu_offloaded_model, device)
|
136 |
+
|
137 |
+
@property
|
138 |
+
def _execution_device(self):
|
139 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
140 |
+
return self.device
|
141 |
+
for module in self.unet.modules():
|
142 |
+
if (
|
143 |
+
hasattr(module, "_hf_hook")
|
144 |
+
and hasattr(module._hf_hook, "execution_device")
|
145 |
+
and module._hf_hook.execution_device is not None
|
146 |
+
):
|
147 |
+
return torch.device(module._hf_hook.execution_device)
|
148 |
+
return self.device
|
149 |
+
|
150 |
+
def decode_latents(self, latents):
|
151 |
+
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
152 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
153 |
+
decoded_latents = self.vae.decode(latents).sample
|
154 |
+
return decoded_latents
|
155 |
+
|
156 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
157 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
158 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
159 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
160 |
+
# and should be between [0, 1]
|
161 |
+
|
162 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
163 |
+
extra_step_kwargs = {}
|
164 |
+
if accepts_eta:
|
165 |
+
extra_step_kwargs["eta"] = eta
|
166 |
+
|
167 |
+
# check if the scheduler accepts generator
|
168 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
169 |
+
if accepts_generator:
|
170 |
+
extra_step_kwargs["generator"] = generator
|
171 |
+
return extra_step_kwargs
|
172 |
+
|
173 |
+
def check_inputs(self, height, width, callback_steps):
|
174 |
+
assert height == width, "Height and width must be equal"
|
175 |
+
|
176 |
+
if height % 8 != 0 or width % 8 != 0:
|
177 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
178 |
+
|
179 |
+
if (callback_steps is None) or (
|
180 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
181 |
+
):
|
182 |
+
raise ValueError(
|
183 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
184 |
+
f" {type(callback_steps)}."
|
185 |
+
)
|
186 |
+
|
187 |
+
def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
|
188 |
+
shape = (
|
189 |
+
batch_size,
|
190 |
+
num_channels_latents,
|
191 |
+
1,
|
192 |
+
height // self.vae_scale_factor,
|
193 |
+
width // self.vae_scale_factor,
|
194 |
+
)
|
195 |
+
rand_device = "cpu" if device.type == "mps" else device
|
196 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
197 |
+
latents = latents.repeat(1, 1, num_frames, 1, 1)
|
198 |
+
|
199 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
200 |
+
latents = latents * self.scheduler.init_noise_sigma
|
201 |
+
return latents
|
202 |
+
|
203 |
+
def prepare_mask_latents(
|
204 |
+
self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
|
205 |
+
):
|
206 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
207 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
208 |
+
# and half precision
|
209 |
+
mask = torch.nn.functional.interpolate(
|
210 |
+
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
211 |
+
)
|
212 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
213 |
+
|
214 |
+
# encode the mask image into latents space so we can concatenate it to the latents
|
215 |
+
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
216 |
+
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
217 |
+
|
218 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
219 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
220 |
+
mask = mask.to(device=device, dtype=dtype)
|
221 |
+
|
222 |
+
# assume batch size = 1
|
223 |
+
mask = rearrange(mask, "f c h w -> 1 c f h w")
|
224 |
+
masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
|
225 |
+
|
226 |
+
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
227 |
+
masked_image_latents = (
|
228 |
+
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
229 |
+
)
|
230 |
+
return mask, masked_image_latents
|
231 |
+
|
232 |
+
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
|
233 |
+
images = images.to(device=device, dtype=dtype)
|
234 |
+
image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
|
235 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
236 |
+
image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
|
237 |
+
image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
238 |
+
|
239 |
+
return image_latents
|
240 |
+
|
241 |
+
def set_progress_bar_config(self, **kwargs):
|
242 |
+
if not hasattr(self, "_progress_bar_config"):
|
243 |
+
self._progress_bar_config = {}
|
244 |
+
self._progress_bar_config.update(kwargs)
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
|
248 |
+
# Paste the surrounding pixels back, because we only want to change the mouth region
|
249 |
+
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
|
250 |
+
masks = masks.to(device=device, dtype=weight_dtype)
|
251 |
+
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
|
252 |
+
return combined_pixel_values
|
253 |
+
|
254 |
+
@staticmethod
|
255 |
+
def pixel_values_to_images(pixel_values: torch.Tensor):
|
256 |
+
pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
|
257 |
+
pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
|
258 |
+
images = (pixel_values * 255).to(torch.uint8)
|
259 |
+
images = images.cpu().numpy()
|
260 |
+
return images
|
261 |
+
|
262 |
+
def affine_transform_video(self, video_path):
|
263 |
+
video_frames = read_video(video_path, use_decord=False)
|
264 |
+
faces = []
|
265 |
+
boxes = []
|
266 |
+
affine_matrices = []
|
267 |
+
print(f"Affine transforming {len(video_frames)} faces...")
|
268 |
+
for frame in tqdm.tqdm(video_frames):
|
269 |
+
face, box, affine_matrix = self.image_processor.affine_transform(frame)
|
270 |
+
faces.append(face)
|
271 |
+
boxes.append(box)
|
272 |
+
affine_matrices.append(affine_matrix)
|
273 |
+
|
274 |
+
faces = torch.stack(faces)
|
275 |
+
return faces, video_frames, boxes, affine_matrices
|
276 |
+
|
277 |
+
def restore_video(self, faces, video_frames, boxes, affine_matrices):
|
278 |
+
video_frames = video_frames[: faces.shape[0]]
|
279 |
+
out_frames = []
|
280 |
+
for index, face in enumerate(faces):
|
281 |
+
x1, y1, x2, y2 = boxes[index]
|
282 |
+
height = int(y2 - y1)
|
283 |
+
width = int(x2 - x1)
|
284 |
+
face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
|
285 |
+
face = rearrange(face, "c h w -> h w c")
|
286 |
+
face = (face / 2 + 0.5).clamp(0, 1)
|
287 |
+
face = (face * 255).to(torch.uint8).cpu().numpy()
|
288 |
+
out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
|
289 |
+
out_frames.append(out_frame)
|
290 |
+
return np.stack(out_frames, axis=0)
|
291 |
+
|
292 |
+
@torch.no_grad()
|
293 |
+
def __call__(
|
294 |
+
self,
|
295 |
+
video_path: str,
|
296 |
+
audio_path: str,
|
297 |
+
video_out_path: str,
|
298 |
+
video_mask_path: str = None,
|
299 |
+
num_frames: int = 16,
|
300 |
+
video_fps: int = 25,
|
301 |
+
audio_sample_rate: int = 16000,
|
302 |
+
height: Optional[int] = None,
|
303 |
+
width: Optional[int] = None,
|
304 |
+
num_inference_steps: int = 20,
|
305 |
+
guidance_scale: float = 1.5,
|
306 |
+
weight_dtype: Optional[torch.dtype] = torch.float16,
|
307 |
+
eta: float = 0.0,
|
308 |
+
mask: str = "fix_mask",
|
309 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
310 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
311 |
+
callback_steps: Optional[int] = 1,
|
312 |
+
**kwargs,
|
313 |
+
):
|
314 |
+
is_train = self.unet.training
|
315 |
+
self.unet.eval()
|
316 |
+
|
317 |
+
# 0. Define call parameters
|
318 |
+
batch_size = 1
|
319 |
+
device = self._execution_device
|
320 |
+
self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
|
321 |
+
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
|
322 |
+
|
323 |
+
video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
|
324 |
+
audio_samples = read_audio(audio_path)
|
325 |
+
|
326 |
+
# 1. Default height and width to unet
|
327 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
328 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
329 |
+
|
330 |
+
# 2. Check inputs
|
331 |
+
self.check_inputs(height, width, callback_steps)
|
332 |
+
|
333 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
334 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
335 |
+
# corresponds to doing no classifier free guidance.
|
336 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
337 |
+
|
338 |
+
# 3. set timesteps
|
339 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
340 |
+
timesteps = self.scheduler.timesteps
|
341 |
+
|
342 |
+
# 4. Prepare extra step kwargs.
|
343 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
344 |
+
|
345 |
+
self.video_fps = video_fps
|
346 |
+
|
347 |
+
if self.unet.add_audio_layer:
|
348 |
+
whisper_feature = self.audio_encoder.audio2feat(audio_path)
|
349 |
+
whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
|
350 |
+
|
351 |
+
num_inferences = min(len(video_frames), len(whisper_chunks)) // num_frames
|
352 |
+
else:
|
353 |
+
num_inferences = len(video_frames) // num_frames
|
354 |
+
|
355 |
+
synced_video_frames = []
|
356 |
+
masked_video_frames = []
|
357 |
+
|
358 |
+
num_channels_latents = self.vae.config.latent_channels
|
359 |
+
|
360 |
+
# Prepare latent variables
|
361 |
+
all_latents = self.prepare_latents(
|
362 |
+
batch_size,
|
363 |
+
num_frames * num_inferences,
|
364 |
+
num_channels_latents,
|
365 |
+
height,
|
366 |
+
width,
|
367 |
+
weight_dtype,
|
368 |
+
device,
|
369 |
+
generator,
|
370 |
+
)
|
371 |
+
|
372 |
+
for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
|
373 |
+
if self.unet.add_audio_layer:
|
374 |
+
audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
|
375 |
+
audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
|
376 |
+
if do_classifier_free_guidance:
|
377 |
+
empty_audio_embeds = torch.zeros_like(audio_embeds)
|
378 |
+
audio_embeds = torch.cat([empty_audio_embeds, audio_embeds])
|
379 |
+
else:
|
380 |
+
audio_embeds = None
|
381 |
+
inference_video_frames = video_frames[i * num_frames : (i + 1) * num_frames]
|
382 |
+
latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
|
383 |
+
pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
|
384 |
+
inference_video_frames, affine_transform=False
|
385 |
+
)
|
386 |
+
|
387 |
+
# 7. Prepare mask latent variables
|
388 |
+
mask_latents, masked_image_latents = self.prepare_mask_latents(
|
389 |
+
masks,
|
390 |
+
masked_pixel_values,
|
391 |
+
height,
|
392 |
+
width,
|
393 |
+
weight_dtype,
|
394 |
+
device,
|
395 |
+
generator,
|
396 |
+
do_classifier_free_guidance,
|
397 |
+
)
|
398 |
+
|
399 |
+
# 8. Prepare image latents
|
400 |
+
image_latents = self.prepare_image_latents(
|
401 |
+
pixel_values,
|
402 |
+
device,
|
403 |
+
weight_dtype,
|
404 |
+
generator,
|
405 |
+
do_classifier_free_guidance,
|
406 |
+
)
|
407 |
+
|
408 |
+
# 9. Denoising loop
|
409 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
410 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
411 |
+
for j, t in enumerate(timesteps):
|
412 |
+
# expand the latents if we are doing classifier free guidance
|
413 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
414 |
+
|
415 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
416 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
417 |
+
latent_model_input = torch.cat(
|
418 |
+
[latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
|
419 |
+
)
|
420 |
+
|
421 |
+
# predict the noise residual
|
422 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample
|
423 |
+
|
424 |
+
# perform guidance
|
425 |
+
if do_classifier_free_guidance:
|
426 |
+
noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
|
427 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
|
428 |
+
|
429 |
+
# compute the previous noisy sample x_t -> x_t-1
|
430 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
431 |
+
|
432 |
+
# call the callback, if provided
|
433 |
+
if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
|
434 |
+
progress_bar.update()
|
435 |
+
if callback is not None and j % callback_steps == 0:
|
436 |
+
callback(j, t, latents)
|
437 |
+
|
438 |
+
# Recover the pixel values
|
439 |
+
decoded_latents = self.decode_latents(latents)
|
440 |
+
decoded_latents = self.paste_surrounding_pixels_back(
|
441 |
+
decoded_latents, pixel_values, 1 - masks, device, weight_dtype
|
442 |
+
)
|
443 |
+
synced_video_frames.append(decoded_latents)
|
444 |
+
masked_video_frames.append(masked_pixel_values)
|
445 |
+
|
446 |
+
synced_video_frames = self.restore_video(
|
447 |
+
torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
|
448 |
+
)
|
449 |
+
masked_video_frames = self.restore_video(
|
450 |
+
torch.cat(masked_video_frames), original_video_frames, boxes, affine_matrices
|
451 |
+
)
|
452 |
+
|
453 |
+
audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
|
454 |
+
audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
|
455 |
+
|
456 |
+
if is_train:
|
457 |
+
self.unet.train()
|
458 |
+
|
459 |
+
temp_dir = "temp"
|
460 |
+
if os.path.exists(temp_dir):
|
461 |
+
shutil.rmtree(temp_dir)
|
462 |
+
os.makedirs(temp_dir, exist_ok=True)
|
463 |
+
|
464 |
+
write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
|
465 |
+
# write_video(video_mask_path, masked_video_frames, fps=25)
|
466 |
+
|
467 |
+
sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
|
468 |
+
|
469 |
+
command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
|
470 |
+
subprocess.run(command, shell=True)
|
latentsync/trepa/__init__.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.nn as nn
|
18 |
+
from einops import rearrange
|
19 |
+
from .third_party.VideoMAEv2.utils import load_videomae_model
|
20 |
+
|
21 |
+
|
22 |
+
class TREPALoss:
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
device="cuda",
|
26 |
+
ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
|
27 |
+
):
|
28 |
+
self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
|
29 |
+
self.model.requires_grad_(False)
|
30 |
+
self.bce_loss = nn.BCELoss()
|
31 |
+
|
32 |
+
def __call__(self, videos_fake, videos_real, loss_type="mse"):
|
33 |
+
batch_size = videos_fake.shape[0]
|
34 |
+
num_frames = videos_fake.shape[2]
|
35 |
+
videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
|
36 |
+
videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
|
37 |
+
|
38 |
+
videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
|
39 |
+
videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")
|
40 |
+
|
41 |
+
videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
|
42 |
+
videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
|
43 |
+
|
44 |
+
# Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
|
45 |
+
videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
|
46 |
+
videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
|
47 |
+
|
48 |
+
feats_fake = self.model.forward_features(videos_fake)
|
49 |
+
feats_real = self.model.forward_features(videos_real)
|
50 |
+
|
51 |
+
feats_fake = F.normalize(feats_fake, p=2, dim=1)
|
52 |
+
feats_real = F.normalize(feats_real, p=2, dim=1)
|
53 |
+
|
54 |
+
return F.mse_loss(feats_fake, feats_real)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
# input shape: (b, c, f, h, w)
|
59 |
+
videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
|
60 |
+
videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
|
61 |
+
|
62 |
+
trepa_loss = TREPALoss(device="cuda")
|
63 |
+
loss = trepa_loss(videos_fake, videos_real)
|
64 |
+
print(loss)
|
latentsync/trepa/third_party/VideoMAEv2/__init__.py
ADDED
File without changes
|