Feature/1115 support flux controlnet v2 #25

Open · wants to merge 74 commits into main

Commits (74)
1be9e7a  add task_type support  (chuck-ma, Nov 10, 2024)
c9156c4  add task_type support  (chuck-ma, Nov 10, 2024)
d0eb617  add task_type support  (chuck-ma, Nov 10, 2024)
2dc53fa  add task_type support  (chuck-ma, Nov 10, 2024)
0f1fcb7  add task_type support  (chuck-ma, Nov 10, 2024)
e02d3a8  add task_type support  (chuck-ma, Nov 10, 2024)
ea0b270  add task_type support  (chuck-ma, Nov 10, 2024)
48576b7  add task_type support  (chuck-ma, Nov 10, 2024)
2d11094  add task_type support  (chuck-ma, Nov 10, 2024)
773f0e2  add task_type support  (chuck-ma, Nov 10, 2024)
87f8d14  add task_type support  (chuck-ma, Nov 10, 2024)
bf48058  add task_type support  (chuck-ma, Nov 10, 2024)
2cd369b  add task_type support  (chuck-ma, Nov 10, 2024)
507bf1e  add task_type support  (chuck-ma, Nov 11, 2024)
a04c246  add task_type support  (chuck-ma, Nov 11, 2024)
491a841  add task_type support  (chuck-ma, Nov 12, 2024)
0808008  add task_type support  (chuck-ma, Nov 12, 2024)
736111f  add task_type support  (chuck-ma, Nov 12, 2024)
5c56cf6  add task_type support  (chuck-ma, Nov 12, 2024)
47a90c5  add task_type support  (chuck-ma, Nov 12, 2024)
9e806bc  add task_type support  (chuck-ma, Nov 12, 2024)
a9089ee  add task_type support  (chuck-ma, Nov 12, 2024)
91924b7  add task_type support  (chuck-ma, Nov 12, 2024)
9752542  add task_type support  (chuck-ma, Nov 12, 2024)
31e2324  add task_type support  (chuck-ma, Nov 12, 2024)
34f76a2  add task_type support  (chuck-ma, Nov 12, 2024)
bbff138  add task_type support  (chuck-ma, Nov 12, 2024)
f75a0b7  add task_type support  (chuck-ma, Nov 13, 2024)
c992268  add task_type support  (chuck-ma, Nov 13, 2024)
c6e4487  add task_type support  (chuck-ma, Nov 13, 2024)
afea76b  add task_type support  (chuck-ma, Nov 13, 2024)
ea0298b  add task_type support  (chuck-ma, Nov 13, 2024)
be57816  add task_type support  (chuck-ma, Nov 13, 2024)
196a3d5  add task_type support  (chuck-ma, Nov 14, 2024)
2266a5b  add task_type support  (chuck-ma, Nov 14, 2024)
a4451a1  add task_type support  (chuck-ma, Nov 14, 2024)
a498a17  add task_type support  (chuck-ma, Nov 14, 2024)
6e3cc29  add task_type support  (chuck-ma, Nov 14, 2024)
50e9155  add task_type support  (chuck-ma, Nov 14, 2024)
3dfb6cd  add task_type support  (chuck-ma, Nov 14, 2024)
927fea4  add task_type support  (chuck-ma, Nov 14, 2024)
79a12fe  add task_type support  (chuck-ma, Nov 14, 2024)
3a35063  add task_type support  (chuck-ma, Nov 14, 2024)
e72f784  add task_type support  (chuck-ma, Nov 15, 2024)
b370f95  add task_type support  (chuck-ma, Nov 15, 2024)
8bbe5e6  add task_type support  (chuck-ma, Nov 15, 2024)
c46422b  add task_type support  (chuck-ma, Nov 15, 2024)
cee10b8  add task_type support  (chuck-ma, Nov 15, 2024)
f8eaa50  add task_type support  (chuck-ma, Nov 15, 2024)
1cb2317  add task_type support  (chuck-ma, Nov 15, 2024)
c3082c9  add task_type support  (chuck-ma, Nov 15, 2024)
ff3192c  add task_type support  (chuck-ma, Nov 15, 2024)
5789ebe  add task_type support  (chuck-ma, Nov 15, 2024)
061f2b8  add task_type support  (chuck-ma, Nov 15, 2024)
ef8e81c  add task_type support  (chuck-ma, Nov 15, 2024)
ba5a302  add task_type support  (chuck-ma, Nov 15, 2024)
96d4e0f  add task_type support  (chuck-ma, Nov 15, 2024)
ae096f5  add task_type support  (chuck-ma, Nov 15, 2024)
fe908d8  add task_type support  (chuck-ma, Nov 15, 2024)
f092e40  add task_type support  (chuck-ma, Nov 15, 2024)
1ecaa4e  add task_type support  (chuck-ma, Nov 15, 2024)
78a52db  add task_type support  (chuck-ma, Nov 15, 2024)
439d314  add task_type support  (chuck-ma, Nov 15, 2024)
e3d8466  add task_type support  (chuck-ma, Nov 15, 2024)
a743169  Merge remote-tracking branch 'main-source/main' into feature/1115-sup…  (chuck-ma, Nov 15, 2024)
639ad8a  add task_type support  (chuck-ma, Nov 15, 2024)
20b482e  add task_type support  (chuck-ma, Nov 15, 2024)
805dfbd  add task_type support  (chuck-ma, Nov 15, 2024)
d1acbe5  add task_type support  (chuck-ma, Nov 15, 2024)
3ee9489  add task_type support  (chuck-ma, Nov 15, 2024)
f4aae55  add task_type support  (chuck-ma, Nov 15, 2024)
7e2f3b6  add task_type support  (chuck-ma, Nov 15, 2024)
7fc49dc  add task_type support  (chuck-ma, Nov 15, 2024)
0aa9f94  add task_type support  (chuck-ma, Nov 15, 2024)
5 changes: 5 additions & 0 deletions README.md
@@ -57,7 +57,12 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
cd nunchaku
git submodule init
git submodule update
# Set TMPDIR to a directory with free space to avoid running out of disk space on shared servers

pip install -e .

# CFLAGS="-Wno-error" CXXFLAGS="-Wno-error" MAX_JOBS=6 pip install --no-deps --no-build-isolation -e .

```

## Usage Example
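For orientation, the sketch below shows how the controlnet path added in this PR would be driven end to end. The diffusers classes and call signature are standard; the checkpoint IDs and the step that swaps in this repo's quantized transformer are placeholders, not the repo's confirmed API.

```python
# Hedged end-to-end sketch of the Flux controlnet path this PR targets.
# Checkpoint IDs are examples; the nunchaku transformer hookup is a placeholder.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
# Placeholder: replace pipe.transformer with the SVDQuant Flux transformer from
# this repo so its forward receives controlnet_block_samples at each step.
pipe.to("cuda")

control_image = load_image("canny_edges.png")  # example conditioning image
image = pipe(
    "a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
).images[0]
image.save("flux_controlnet.png")
```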
25 changes: 23 additions & 2 deletions nunchaku/csrc/flux.h
@@ -39,7 +39,9 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single)
torch::Tensor rotary_emb_single,
const std::vector<torch::Tensor>* controlnet_block_samples = nullptr,
const std::vector<torch::Tensor>* controlnet_single_block_samples = nullptr)
{
checkModel();

@@ -52,13 +54,31 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
rotary_emb_context = rotary_emb_context.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();

// Convert controlnet samples from torch tensors
std::vector<Tensor> block_samples;
std::vector<Tensor> single_block_samples;

if (controlnet_block_samples) {
for (const auto& t : *controlnet_block_samples) {
block_samples.push_back(from_torch(t));
}
}

if (controlnet_single_block_samples) {
for (const auto& t : *controlnet_single_block_samples) {
single_block_samples.push_back(from_torch(t));
}
}

Tensor result = net->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
from_torch(rotary_emb_single)
from_torch(rotary_emb_single),
block_samples.empty() ? nullptr : &block_samples,
single_block_samples.empty() ? nullptr : &single_block_samples
);

torch::Tensor output = to_torch(result);
@@ -183,6 +203,7 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
}
}
});

}

void forceFP16Attention(bool enable) {
25 changes: 22 additions & 3 deletions nunchaku/models/flux.py
@@ -27,6 +27,8 @@ def forward(
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
):
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
@@ -43,13 +45,27 @@
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
image_rotary_emb = image_rotary_emb.reshape(
[batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
)
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)

if controlnet_block_samples is None:
controlnet_block_samples = []
if controlnet_single_block_samples is None:
controlnet_single_block_samples = []

hidden_states = self.m.forward(
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
)

hidden_states = hidden_states.to(original_dtype)
@@ -95,7 +111,10 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
if Version(diffusers.__version__) >= Version("0.31.0"):
ids = ids[None, ...]
n_axes = ids.shape[-1]
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)


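On the Python side, both new arguments are optional lists of per-block residual tensors; None is normalized to an empty list before the call into the C++ forward, where an empty list means "no controlnet". A minimal sketch of that contract, with the token count and hidden size chosen only for illustration:

```python
# Sketch of the Python-side contract for the new arguments (shapes assumed):
# each entry is one residual tensor added to the image-token hidden states.
from typing import List, Optional
import torch

def normalize_controlnet_samples(
    samples: Optional[List[torch.Tensor]],
) -> List[torch.Tensor]:
    """None means 'no controlnet'; the binding then receives an empty list."""
    return [] if samples is None else list(samples)

# Example: three residuals for batch size 1, 4096 image tokens, hidden dim 3072
# (values assumed here just to show the expected [bs, img_tokens, dim] layout).
dummy = [torch.zeros(1, 4096, 3072, dtype=torch.bfloat16) for _ in range(3)]
assert len(normalize_controlnet_samples(dummy)) == 3
assert normalize_controlnet_samples(None) == []
```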
64 changes: 59 additions & 5 deletions src/FluxModel.cpp
@@ -616,16 +616,45 @@ FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
}
}

Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) {

Tensor FluxModel::forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
const std::vector<Tensor>* controlnet_block_samples,
const std::vector<Tensor>* controlnet_single_block_samples
) {

const int batch_size = hidden_states.shape[0];
const Tensor::ScalarType dtype = hidden_states.dtype();
const Device device = hidden_states.device();

const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];

for (auto &&block : transformer_blocks) {
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
// for (auto &&block : transformer_blocks) {
// std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
// }

// Joint transformer blocks with controlnet
for (size_t i = 0; i < transformer_blocks.size(); i++) {
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks[i]->forward(
hidden_states, encoder_hidden_states, temb,
rotary_emb_img, rotary_emb_context, 0.0f
);

// Add controlnet residual if available
if (controlnet_block_samples && !controlnet_block_samples->empty()) {
int interval = std::ceil(
float(transformer_blocks.size()) / controlnet_block_samples->size()
);
if (i % interval == 0) {
hidden_states = add(hidden_states, (*controlnet_block_samples)[i / interval]);
}
}
}

// txt first, same as diffusers
Expand All @@ -637,8 +666,33 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
hidden_states = concat;
encoder_hidden_states = {};

for (auto &&block : single_transformer_blocks) {
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
// for (auto &&block : single_transformer_blocks) {
// hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
// }

// Single transformer blocks with controlnet
for (size_t i = 0; i < single_transformer_blocks.size(); i++) {
hidden_states = single_transformer_blocks[i]->forward(
hidden_states, temb, rotary_emb_single
);

// Add controlnet residual if available
if (controlnet_single_block_samples && !controlnet_single_block_samples->empty()) {
int interval = std::ceil(
float(single_transformer_blocks.size()) / controlnet_single_block_samples->size()
);
if (i % interval == 0) {
// Only apply to image tokens, not text tokens
Tensor img_states = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
img_states = add(img_states, (*controlnet_single_block_samples)[i / interval]);
// Copy back
for (int b = 0; b < batch_size; b++) {
hidden_states.slice(0, b, b + 1)
.slice(1, txt_tokens, txt_tokens + img_tokens)
.copy_(img_states.slice(0, b, b + 1));
}
}
}
}

return hidden_states;
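To make the residual-injection rule concrete: with interval = ceil(num_blocks / num_samples), block i receives controlnet sample i / interval whenever i % interval == 0, so the samples are spread evenly across the blocks. A small sketch of that schedule (the 19 joint-block count is taken from Flux.1-dev as an assumption for the example):

```python
# Illustration of the ceil-interval schedule used above for controlnet residuals.
import math

def controlnet_injection_schedule(num_blocks: int, num_samples: int) -> dict[int, int]:
    """Map block index -> controlnet sample index for blocks that get a residual."""
    interval = math.ceil(num_blocks / num_samples)
    return {i: i // interval for i in range(num_blocks) if i % interval == 0}

# e.g. 19 joint transformer blocks and 4 controlnet block samples:
print(controlnet_injection_schedule(19, 4))  # {0: 0, 5: 1, 10: 2, 15: 3}
```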
5 changes: 4 additions & 1 deletion src/FluxModel.h
@@ -129,7 +129,10 @@ class JointTransformerBlock : public Module {
class FluxModel : public Module {
public:
FluxModel(Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb,
Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single,
const std::vector<Tensor>* controlnet_block_samples = nullptr,
const std::vector<Tensor>* controlnet_single_block_samples = nullptr);

public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;