diff --git a/README.md b/README.md
index 859c5a9..70b579d 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,12 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
     cd nunchaku
     git submodule init
     git submodule update
+	# Tip: if the build fills /tmp on shared servers, point TMPDIR at a directory with free space
+
 	pip install -e .
+
+	# Fallback if the editable install fails to compile: CFLAGS="-Wno-error" CXXFLAGS="-Wno-error" MAX_JOBS=6 pip install --no-deps --no-build-isolation -e .
+
 	```
 
 ## Usage Example
diff --git a/nunchaku/csrc/flux.h b/nunchaku/csrc/flux.h
index a389612..28f0f54 100644
--- a/nunchaku/csrc/flux.h
+++ b/nunchaku/csrc/flux.h
@@ -39,7 +39,9 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
         torch::Tensor temb, 
         torch::Tensor rotary_emb_img, 
         torch::Tensor rotary_emb_context, 
-        torch::Tensor rotary_emb_single) 
+        torch::Tensor rotary_emb_single,
+        const std::vector<torch::Tensor>* controlnet_block_samples = nullptr,
+        const std::vector<torch::Tensor>* controlnet_single_block_samples = nullptr) 
     {
         checkModel();
 
@@ -52,13 +54,31 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
         rotary_emb_context = rotary_emb_context.contiguous();
         rotary_emb_single = rotary_emb_single.contiguous();
 
+        // Convert the optional controlnet residual samples from torch tensors to internal Tensors
+        std::vector<Tensor> block_samples;
+        std::vector<Tensor> single_block_samples;
+        
+        if (controlnet_block_samples) {
+            for (const auto& t : *controlnet_block_samples) {
+                block_samples.push_back(from_torch(t));
+            }
+        }
+        
+        if (controlnet_single_block_samples) {
+            for (const auto& t : *controlnet_single_block_samples) {
+                single_block_samples.push_back(from_torch(t));
+            }
+        }
+
         Tensor result = net->forward(
             from_torch(hidden_states),
             from_torch(encoder_hidden_states),
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            from_torch(rotary_emb_single)
+            from_torch(rotary_emb_single),
+            block_samples.empty() ? nullptr : &block_samples,
+            single_block_samples.empty() ? nullptr : &single_block_samples
         );
 
         torch::Tensor output = to_torch(result);
@@ -183,6 +203,7 @@ class QuantizedFluxModel { // : public torch::CustomClassHolder {
                 }
             }
         });
+
     }
 
     void forceFP16Attention(bool enable) {
diff --git a/nunchaku/models/flux.py b/nunchaku/models/flux.py
index 6e04863..f2d455f 100644
--- a/nunchaku/models/flux.py
+++ b/nunchaku/models/flux.py
@@ -27,6 +27,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         image_rotary_emb: torch.Tensor,
         joint_attention_kwargs=None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
     ):
         batch_size = hidden_states.shape[0]
         txt_tokens = encoder_hidden_states.shape[1]
@@ -43,13 +45,27 @@ def forward(
         assert image_rotary_emb.shape[1] == 1
         assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
         # [bs, tokens, head_dim / 2, 1, 2] (sincos)
-        image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+        image_rotary_emb = image_rotary_emb.reshape(
+            [batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
+        )
         rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
         rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
         rotary_emb_single = image_rotary_emb  # .to(self.dtype)
 
+        if controlnet_block_samples is None:
+            controlnet_block_samples = []
+        if controlnet_single_block_samples is None:
+            controlnet_single_block_samples = []
+
         hidden_states = self.m.forward(
-            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
+            hidden_states,
+            encoder_hidden_states,
+            temb,
+            rotary_emb_img,
+            rotary_emb_txt,
+            rotary_emb_single,
+            controlnet_block_samples,
+            controlnet_single_block_samples,
         )
 
         hidden_states = hidden_states.to(original_dtype)
@@ -95,7 +111,10 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         if Version(diffusers.__version__) >= Version("0.31.0"):
             ids = ids[None, ...]
         n_axes = ids.shape[-1]
-        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
         return emb.unsqueeze(1)
 
 
diff --git a/src/FluxModel.cpp b/src/FluxModel.cpp
index 5d66e0c..bb2feea 100644
--- a/src/FluxModel.cpp
+++ b/src/FluxModel.cpp
@@ -616,7 +616,18 @@ FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
     }
 }
 
-Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) {
+
+Tensor FluxModel::forward(
+    Tensor hidden_states,
+    Tensor encoder_hidden_states,
+    Tensor temb,
+    Tensor rotary_emb_img,
+    Tensor rotary_emb_context,
+    Tensor rotary_emb_single,
+    const std::vector<Tensor>* controlnet_block_samples,
+    const std::vector<Tensor>* controlnet_single_block_samples
+) {
+
     const int batch_size = hidden_states.shape[0];
     const Tensor::ScalarType dtype = hidden_states.dtype();
     const Device device = hidden_states.device();
@@ -624,8 +635,26 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     const int txt_tokens = encoder_hidden_states.shape[1];
     const int img_tokens = hidden_states.shape[1];
 
-    for (auto &&block : transformer_blocks) {
-        std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+    // for (auto &&block : transformer_blocks) {
+    //     std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+    // }
+
+    // Joint transformer blocks, with optional controlnet residuals added between blocks
+    for (size_t i = 0; i < transformer_blocks.size(); i++) {
+        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks[i]->forward(
+            hidden_states, encoder_hidden_states, temb, 
+            rotary_emb_img, rotary_emb_context, 0.0f
+        );
+
+        // Add controlnet residual if available
+        if (controlnet_block_samples && !controlnet_block_samples->empty()) {
+            int interval = std::ceil(
+                float(transformer_blocks.size()) / controlnet_block_samples->size()
+            );
+            if (i % interval == 0) {
+                hidden_states = add(hidden_states, (*controlnet_block_samples)[i / interval]);
+            }
+        }
     }
 
     // txt first, same as diffusers
@@ -637,8 +666,33 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     hidden_states = concat;
     encoder_hidden_states = {};
 
-    for (auto &&block : single_transformer_blocks) {
-        hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+    // for (auto &&block : single_transformer_blocks) {
+    //     hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+    // }
+
+    // Single transformer blocks with controlnet
+    for (size_t i = 0; i < single_transformer_blocks.size(); i++) {
+        hidden_states = single_transformer_blocks[i]->forward(
+            hidden_states, temb, rotary_emb_single
+        );
+
+        // Add controlnet residual if available
+        if (controlnet_single_block_samples && !controlnet_single_block_samples->empty()) {
+            int interval = std::ceil(
+                float(single_transformer_blocks.size()) / controlnet_single_block_samples->size()
+            );
+            if (i % interval == 0) {
+                // Only apply to image tokens, not text tokens
+                Tensor img_states = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                img_states = add(img_states, (*controlnet_single_block_samples)[i / interval]);
+                // Copy back
+                for (int b = 0; b < batch_size; b++) {
+                    hidden_states.slice(0, b, b + 1)
+                               .slice(1, txt_tokens, txt_tokens + img_tokens)
+                               .copy_(img_states.slice(0, b, b + 1));
+                }
+            }
+        }
     }
 
     return hidden_states;
diff --git a/src/FluxModel.h b/src/FluxModel.h
index 798db33..6ff8ffb 100644
--- a/src/FluxModel.h
+++ b/src/FluxModel.h
@@ -129,7 +129,10 @@ class JointTransformerBlock : public Module {
 class FluxModel : public Module {
 public:
     FluxModel(Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
+    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, 
+    Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single,
+    const std::vector<Tensor>* controlnet_block_samples = nullptr,
+    const std::vector<Tensor>* controlnet_single_block_samples = nullptr);
 
 public:
     std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;