diff --git a/README.md b/README.md index 859c5a9..70b579d 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,12 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat cd nunchaku git submodule init git submodule update + # Optionally set TMPDIR to a directory with ample free space to avoid running out of disk space on shared servers + pip install -e . + + # CFLAGS="-Wno-error" CXXFLAGS="-Wno-error" MAX_JOBS=6 pip install --no-deps --no-build-isolation -e . + ``` ## Usage Example diff --git a/nunchaku/csrc/flux.h b/nunchaku/csrc/flux.h index a389612..28f0f54 100644 --- a/nunchaku/csrc/flux.h +++ b/nunchaku/csrc/flux.h @@ -39,7 +39,9 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder { torch::Tensor temb, torch::Tensor rotary_emb_img, torch::Tensor rotary_emb_context, - torch::Tensor rotary_emb_single) + torch::Tensor rotary_emb_single, + const std::vector<torch::Tensor>* controlnet_block_samples = nullptr, + const std::vector<torch::Tensor>* controlnet_single_block_samples = nullptr) { checkModel(); @@ -52,13 +54,31 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder { rotary_emb_context = rotary_emb_context.contiguous(); rotary_emb_single = rotary_emb_single.contiguous(); + // Convert optional controlnet samples from torch tensors to internal Tensors + std::vector<Tensor> block_samples; + std::vector<Tensor> single_block_samples; + + if (controlnet_block_samples) { + for (const auto& t : *controlnet_block_samples) { + block_samples.push_back(from_torch(t)); + } + } + + if (controlnet_single_block_samples) { + for (const auto& t : *controlnet_single_block_samples) { + single_block_samples.push_back(from_torch(t)); + } + } + Tensor result = net->forward( from_torch(hidden_states), from_torch(encoder_hidden_states), from_torch(temb), from_torch(rotary_emb_img), from_torch(rotary_emb_context), - from_torch(rotary_emb_single) + from_torch(rotary_emb_single), + block_samples.empty() ? nullptr : &block_samples, + single_block_samples.empty() ? 
nullptr : &single_block_samples ); torch::Tensor output = to_torch(result); @@ -183,6 +203,7 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder { } } }); + } void forceFP16Attention(bool enable) { diff --git a/nunchaku/models/flux.py b/nunchaku/models/flux.py index 6e04863..f2d455f 100644 --- a/nunchaku/models/flux.py +++ b/nunchaku/models/flux.py @@ -27,6 +27,8 @@ def forward( encoder_hidden_states: torch.Tensor, image_rotary_emb: torch.Tensor, joint_attention_kwargs=None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, ): batch_size = hidden_states.shape[0] txt_tokens = encoder_hidden_states.shape[1] @@ -43,13 +45,27 @@ def forward( assert image_rotary_emb.shape[1] == 1 assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens) # [bs, tokens, head_dim / 2, 1, 2] (sincos) - image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]) + image_rotary_emb = image_rotary_emb.reshape( + [batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]] + ) rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype) rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype) rotary_emb_single = image_rotary_emb # .to(self.dtype) + if controlnet_block_samples is None: + controlnet_block_samples = [] + if controlnet_single_block_samples is None: + controlnet_single_block_samples = [] + hidden_states = self.m.forward( - hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single + hidden_states, + encoder_hidden_states, + temb, + rotary_emb_img, + rotary_emb_txt, + rotary_emb_single, + controlnet_block_samples, + controlnet_single_block_samples, ) hidden_states = hidden_states.to(original_dtype) @@ -95,7 +111,10 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: if Version(diffusers.__version__) >= Version("0.31.0"): ids = ids[None, ...] 
n_axes = ids.shape[-1] - emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) return emb.unsqueeze(1) diff --git a/src/FluxModel.cpp b/src/FluxModel.cpp index 5d66e0c..bb2feea 100644 --- a/src/FluxModel.cpp +++ b/src/FluxModel.cpp @@ -616,7 +616,18 @@ FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) { } } -Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) { + +Tensor FluxModel::forward( + Tensor hidden_states, + Tensor encoder_hidden_states, + Tensor temb, + Tensor rotary_emb_img, + Tensor rotary_emb_context, + Tensor rotary_emb_single, + const std::vector<Tensor>* controlnet_block_samples, + const std::vector<Tensor>* controlnet_single_block_samples +) { + const int batch_size = hidden_states.shape[0]; const Tensor::ScalarType dtype = hidden_states.dtype(); const Device device = hidden_states.device(); @@ -624,8 +635,26 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te const int txt_tokens = encoder_hidden_states.shape[1]; const int img_tokens = hidden_states.shape[1]; - for (auto &&block : transformer_blocks) { - std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f); + // for (auto &&block : transformer_blocks) { + // std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f); + // } + + // Joint transformer blocks with controlnet + for (size_t i = 0; i < transformer_blocks.size(); i++) { + std::tie(hidden_states, encoder_hidden_states) = transformer_blocks[i]->forward( + hidden_states, encoder_hidden_states, temb, + rotary_emb_img, rotary_emb_context, 0.0f + ); + + 
// Add controlnet residual if available (diffusers adds it after EVERY block, indexed by i / interval) + if (controlnet_block_samples && !controlnet_block_samples->empty()) { + int interval = std::ceil( + float(transformer_blocks.size()) / controlnet_block_samples->size() + ); + hidden_states = add(hidden_states, (*controlnet_block_samples)[i / interval]); + } } // txt first, same as diffusers @@ -637,8 +666,33 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te hidden_states = concat; encoder_hidden_states = {}; - for (auto &&block : single_transformer_blocks) { - hidden_states = block->forward(hidden_states, temb, rotary_emb_single); + // for (auto &&block : single_transformer_blocks) { + // hidden_states = block->forward(hidden_states, temb, rotary_emb_single); + // } + + // Single transformer blocks with controlnet + for (size_t i = 0; i < single_transformer_blocks.size(); i++) { + hidden_states = single_transformer_blocks[i]->forward( + hidden_states, temb, rotary_emb_single + ); + + // Add controlnet residual if available (after every block, image tokens only, same as diffusers) + if (controlnet_single_block_samples && !controlnet_single_block_samples->empty()) { + int interval = std::ceil( + float(single_transformer_blocks.size()) / controlnet_single_block_samples->size() + ); + // Only apply to image tokens, not text tokens + Tensor img_states = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens); + img_states = add(img_states, (*controlnet_single_block_samples)[i / interval]); + // Copy back per batch element + for (int b = 0; b < batch_size; b++) { + hidden_states.slice(0, b, b + 1) + .slice(1, txt_tokens, txt_tokens + img_tokens) + .copy_(img_states.slice(0, b, b + 1)); + } + } } return hidden_states; diff --git a/src/FluxModel.h b/src/FluxModel.h index 798db33..6ff8ffb 100644 --- a/src/FluxModel.h +++ b/src/FluxModel.h @@ -129,7 +129,10 @@ class JointTransformerBlock : public Module { class FluxModel : public Module { public: FluxModel(Tensor::ScalarType dtype, Device device); - Tensor 
forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single); + Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, + Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, + const std::vector<Tensor>* controlnet_block_samples = nullptr, + const std::vector<Tensor>* controlnet_single_block_samples = nullptr); public: std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;