Skip to content

Commit 9e28be6

Browse files
authored
feat: add chroma radiance support (#910)
* add chroma radiance support * fix ci * simplify generate_init_latent * workaround: avoid ggml cuda error * format code * add chroma radiance doc
1 parent 062490a commit 9e28be6

File tree

10 files changed

+630
-225
lines changed

10 files changed

+630
-225
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ API and command-line option may change frequently.***
3535
- Image Models
3636
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
3737
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
38-
- [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
38+
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
3939
- [SD3/SD3.5](./docs/sd3.md)
4040
- [Flux-dev/Flux-schnell](./docs/flux.md)
4141
- [Chroma](./docs/chroma.md)
42+
- [Chroma1-Radiance](./docs/chroma_radiance.md)
4243
- [Qwen Image](./docs/qwen_image.md)
4344
- Image Edit Models
4445
- [FLUX.1-Kontext-dev](./docs/kontext.md)

assets/flux/chroma1-radiance.png

477 KB
Loading

docs/chroma_radiance.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# How to Use
2+
3+
## Download weights
4+
5+
- Download Chroma1-Radiance
6+
- safetensors: https://huggingface.co/lodestones/Chroma1-Radiance/tree/main
7+
- gguf: https://huggingface.co/silveroxides/Chroma1-Radiance-GGUF/tree/main
8+
9+
- Download t5xxl
10+
- safetensors: https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
11+
12+
## Examples
13+
14+
```
15+
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Chroma1-Radiance-v0.4-Q8_0.gguf --t5xxl ..\..\ComfyUI\models\clip\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma radiance cpp'" --cfg-scale 4.0 --sampling-method euler -v
16+
```
17+
18+
<img alt="Chroma1-Radiance" src="../assets/flux/chroma1-radiance.png" />
19+
20+
21+

flux.hpp

Lines changed: 459 additions & 111 deletions
Large diffs are not rendered by default.

ggml_extend.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
954954
if (scale != 1.f) {
955955
x = ggml_scale(ctx, x, scale);
956956
}
957-
x = ggml_mul_mat(ctx, w, x);
957+
if (x->ne[2] * x->ne[3] > 1024) {
958+
// workaround: avoid ggml cuda error
959+
int64_t ne2 = x->ne[2];
960+
int64_t ne3 = x->ne[3];
961+
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
962+
x = ggml_mul_mat(ctx, w, x);
963+
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
964+
} else {
965+
x = ggml_mul_mat(ctx, w, x);
966+
}
958967
if (force_prec_f32) {
959968
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
960969
}

model.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,6 @@ bool ModelLoader::model_is_unet() {
17781778

17791779
SDVersion ModelLoader::get_sd_version() {
17801780
TensorStorage token_embedding_weight, input_block_weight;
1781-
bool input_block_checked = false;
17821781

17831782
bool has_multiple_encoders = false;
17841783
bool is_unet = false;
@@ -1791,12 +1790,12 @@ SDVersion ModelLoader::get_sd_version() {
17911790
bool has_middle_block_1 = false;
17921791

17931792
for (auto& tensor_storage : tensor_storages) {
1794-
if (!(is_xl || is_flux)) {
1793+
if (!(is_xl)) {
17951794
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
17961795
is_flux = true;
1797-
if (input_block_checked) {
1798-
break;
1799-
}
1796+
}
1797+
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
1798+
return VERSION_CHROMA_RADIANCE;
18001799
}
18011800
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
18021801
return VERSION_SD3;
@@ -1813,22 +1812,19 @@ SDVersion ModelLoader::get_sd_version() {
18131812
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
18141813
has_img_emb = true;
18151814
}
1816-
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
1815+
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
1816+
tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
18171817
is_unet = true;
18181818
if (has_multiple_encoders) {
18191819
is_xl = true;
1820-
if (input_block_checked) {
1821-
break;
1822-
}
18231820
}
18241821
}
1825-
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
1822+
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos ||
1823+
tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
1824+
tensor_storage.name.find("te.1") != std::string::npos) {
18261825
has_multiple_encoders = true;
18271826
if (is_unet) {
18281827
is_xl = true;
1829-
if (input_block_checked) {
1830-
break;
1831-
}
18321828
}
18331829
}
18341830
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
@@ -1848,12 +1844,10 @@ SDVersion ModelLoader::get_sd_version() {
18481844
token_embedding_weight = tensor_storage;
18491845
// break;
18501846
}
1851-
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
1852-
input_block_weight = tensor_storage;
1853-
input_block_checked = true;
1854-
if (is_flux) {
1855-
break;
1856-
}
1847+
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" ||
1848+
tensor_storage.name == "model.diffusion_model.img_in.weight" ||
1849+
tensor_storage.name == "unet.conv_in.weight") {
1850+
input_block_weight = tensor_storage;
18571851
}
18581852
}
18591853
if (is_wan) {

model.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum SDVersion {
3636
VERSION_FLUX_FILL,
3737
VERSION_FLUX_CONTROLS,
3838
VERSION_FLEX_2,
39+
VERSION_CHROMA_RADIANCE,
3940
VERSION_WAN2,
4041
VERSION_WAN2_2_I2V,
4142
VERSION_WAN2_2_TI2V,
@@ -72,7 +73,11 @@ static inline bool sd_version_is_sd3(SDVersion version) {
7273
}
7374

7475
static inline bool sd_version_is_flux(SDVersion version) {
75-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
76+
if (version == VERSION_FLUX ||
77+
version == VERSION_FLUX_FILL ||
78+
version == VERSION_FLUX_CONTROLS ||
79+
version == VERSION_FLEX_2 ||
80+
version == VERSION_CHROMA_RADIANCE) {
7681
return true;
7782
}
7883
return false;

qwen_image.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ namespace Qwen {
649649

650650
static void load_from_file_and_test(const std::string& file_path) {
651651
// cuda q8: pass
652-
// cuda q8 fa: nan
652+
// cuda q8 fa: pass
653653
// ggml_backend_t backend = ggml_backend_cuda_init(0);
654654
ggml_backend_t backend = ggml_backend_cpu_init();
655655
ggml_type model_data_type = GGML_TYPE_Q8_0;

0 commit comments

Comments
 (0)