Skip to content

Commit aaaa027

Browse files
authored
Move lego mode from --lego <track> CLI flag to "lego" JSON request field (#21)
* apply requested changes
1 parent 2575f91 commit aaaa027

6 files changed

Lines changed: 108 additions & 44 deletions

File tree

README.md

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,32 @@ EOF
258258
--vae models/vae-BF16.gguf
259259
```
260260

261-
**Lego** (`--lego <track>` + `--src-audio`):
261+
**Lego** (`"lego"` in JSON + `--src-audio`):
262262
generates a new instrument track layered over an existing backing track.
263263
Only the **base model** (`acestep-v15-base`) supports lego mode.
264-
The track name is passed on the CLI; set `audio_cover_strength=1.0` in the
265-
request so the source audio guides all DiT steps.
266264
See `examples/lego.json` and `examples/lego.sh`.
267265

266+
```bash
267+
cat > /tmp/lego.json << 'EOF'
268+
{
269+
"caption": "electric guitar riff, funk guitar, house music, instrumental",
270+
"lyrics": "[Instrumental]",
271+
"lego": "guitar",
272+
"inference_steps": 50,
273+
"guidance_scale": 7.0,
274+
"shift": 1.0
275+
}
276+
EOF
277+
278+
./build/dit-vae \
279+
--src-audio backing-track.wav \
280+
--request /tmp/lego.json \
281+
--text-encoder models/Qwen3-Embedding-0.6B-Q8_0.gguf \
282+
--dit models/acestep-v15-base-Q8_0.gguf \
283+
--vae models/vae-BF16.gguf \
284+
--wav
285+
```
286+
268287
Available track names: `vocals`, `backing_vocals`, `drums`, `bass`, `guitar`,
269288
`keyboard`, `percussion`, `strings`, `synth`, `fx`, `brass`, `woodwinds`.
270289

@@ -295,7 +314,8 @@ the LLM fills them, or a sensible runtime default is applied.
295314
"shift": 3.0,
296315
"audio_cover_strength": 0.5,
297316
"repainting_start": -1,
298-
"repainting_end": -1
317+
"repainting_end": -1,
318+
"lego": ""
299319
}
300320
```
301321

@@ -363,6 +383,15 @@ the DiT regenerates the `[start, end)` time region while preserving everything
363383
else. `-1` on start means 0s (beginning), `-1` on end means source duration
364384
(end). Error if end <= start after resolve. `audio_cover_strength` is ignored.
365385

386+
**`lego`** (string, default `""` = inactive)
387+
Track name for lego mode. Requires `--src-audio` and the **base model**.
388+
Valid names: `vocals`, `backing_vocals`, `drums`, `bass`, `guitar`,
389+
`keyboard`, `percussion`, `strings`, `synth`, `fx`, `brass`, `woodwinds`.
390+
When set, passes the source audio to the DiT as context and builds the
391+
instruction `"Generate the {TRACK} track based on the audio context:"`.
392+
`audio_cover_strength` is forced to 1.0 (all steps see the source audio).
393+
Use `inference_steps=50`, `guidance_scale=7.0`, `shift=1.0` for the base model.
394+
366395
### LM sampling (ace-qwen3)
367396

368397
**`lm_temperature`** (float, default `0.85`)

examples/lego.json

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
{
2-
"caption": "electric guitar riff, funk guitar, house music, instrumental",
3-
"audio_cover_strength": 1.0,
2+
"caption": "",
3+
"lyrics": "[Instrumental]",
4+
"lego": "guitar",
45
"inference_steps": 50,
5-
"guidance_scale": 7.0
6+
"guidance_scale": 7.0,
7+
"shift": 1.0
68
}

examples/lego.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ set -eu
2323
# Step 2: lego guitar on the generated track (base model required)
2424
../build/dit-vae \
2525
--src-audio simple00.wav \
26-
--lego guitar \
2726
--request lego.json \
2827
--text-encoder ../models/Qwen3-Embedding-0.6B-Q8_0.gguf \
2928
--dit ../models/acestep-v15-base-Q8_0.gguf \

src/request.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ void request_init(AceRequest * r) {
3434
r->audio_cover_strength = 0.5f;
3535
r->repainting_start = -1.0f;
3636
r->repainting_end = -1.0f;
37+
r->lego = "";
3738
}
3839

3940
// JSON string escape / unescape
@@ -321,6 +322,8 @@ bool request_parse(AceRequest * r, const char * path) {
321322
r->repainting_start = (float) atof(v.c_str());
322323
} else if (k == "repainting_end") {
323324
r->repainting_end = (float) atof(v.c_str());
325+
} else if (k == "lego") {
326+
r->lego = v;
324327
}
325328
}
326329

@@ -356,6 +359,9 @@ bool request_write(const AceRequest * r, const char * path) {
356359
fprintf(f, " \"audio_cover_strength\": %.2f,\n", r->audio_cover_strength);
357360
fprintf(f, " \"repainting_start\": %.1f,\n", r->repainting_start);
358361
fprintf(f, " \"repainting_end\": %.1f,\n", r->repainting_end);
362+
if (!r->lego.empty()) {
363+
fprintf(f, " \"lego\": \"%s\",\n", json_escape(r->lego).c_str());
364+
}
359365
// audio_codes last (no trailing comma)
360366
fprintf(f, " \"audio_codes\": \"%s\"\n", json_escape(r->audio_codes).c_str());
361367
fprintf(f, "}\n");
@@ -380,5 +386,8 @@ void request_dump(const AceRequest * r, FILE * f) {
380386
if (r->repainting_start >= 0.0f || r->repainting_end >= 0.0f) {
381387
fprintf(f, " repaint: start=%.1f end=%.1f\n", r->repainting_start, r->repainting_end);
382388
}
389+
if (!r->lego.empty()) {
390+
fprintf(f, " lego: %s\n", r->lego.c_str());
391+
}
383392
fprintf(f, " audio_codes: %s\n", r->audio_codes.empty() ? "(none)" : "(present)");
384393
}

src/request.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ struct AceRequest {
4949
// -1 on start means 0s, -1 on end means source duration.
5050
float repainting_start; // -1
5151
float repainting_end; // -1
52+
53+
// lego mode (requires --src-audio, base model only)
54+
// Track name from TRACK_NAMES: vocals, backing_vocals, drums, bass, guitar,
55+
// keyboard, percussion, strings, synth, fx, brass, woodwinds.
56+
// Empty = not lego. When set, selects the lego instruction and forces audio_cover_strength to 1.0.
57+
std::string lego; // ""
5258
};
5359

5460
// Initialize all fields to defaults (matches Python GenerationParams defaults)

tools/dit-vae.cpp

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "vae-enc.h"
1717
#include "vae.h"
1818

19+
#include <cctype>
1920
#include <cstdio>
2021
#include <cstdlib>
2122
#include <cstring>
@@ -32,11 +33,6 @@ static void print_usage(const char * prog) {
3233
" --vae <gguf> VAE GGUF file\n\n"
3334
"Reference audio:\n"
3435
" --src-audio <file> Source audio (WAV or MP3, any sample rate)\n\n"
35-
"Lego mode (base model only, requires --src-audio):\n"
36-
" --lego <track> Generate a track over the source audio context\n"
37-
" Track names: vocals, backing_vocals, drums, bass,\n"
38-
" guitar, keyboard, percussion, strings, synth,\n"
39-
" fx, brass, woodwinds\n\n"
4036
"LoRA:\n"
4137
" --lora <path> LoRA safetensors file or directory\n"
4238
" --lora-scale <float> LoRA scaling factor (default: 1.0)\n\n"
@@ -88,7 +84,6 @@ int main(int argc, char ** argv) {
8884
const char * dit_gguf = NULL;
8985
const char * vae_gguf = NULL;
9086
const char * src_audio_path = NULL;
91-
const char * lego_track = NULL; // --lego <track>
9287
const char * dump_dir = NULL;
9388
const char * lora_path = NULL;
9489
float lora_scale = 1.0f;
@@ -113,8 +108,6 @@ int main(int argc, char ** argv) {
113108
vae_gguf = argv[++i];
114109
} else if (strcmp(argv[i], "--src-audio") == 0 && i + 1 < argc) {
115110
src_audio_path = argv[++i];
116-
} else if (strcmp(argv[i], "--lego") == 0 && i + 1 < argc) {
117-
lego_track = argv[++i];
118111
} else if (strcmp(argv[i], "--lora") == 0 && i + 1 < argc) {
119112
lora_path = argv[++i];
120113
} else if (strcmp(argv[i], "--lora-scale") == 0 && i + 1 < argc) {
@@ -152,10 +145,6 @@ int main(int argc, char ** argv) {
152145
fprintf(stderr, "[CLI] ERROR: --batch must be 1..9\n");
153146
return 1;
154147
}
155-
if (lego_track && !src_audio_path) {
156-
fprintf(stderr, "[CLI] ERROR: --lego requires --src-audio\n");
157-
return 1;
158-
}
159148
if (!dit_gguf) {
160149
fprintf(stderr, "[CLI] ERROR: --dit required\n");
161150
print_usage(argv[0]);
@@ -198,12 +187,6 @@ int main(int argc, char ** argv) {
198187
if (gf_load(&gf, dit_gguf)) {
199188
is_turbo = gf_get_bool(gf, "acestep.is_turbo");
200189
const void * sl_data = gf_get_data(gf, "silence_latent");
201-
if (lego_track && is_turbo) {
202-
fprintf(stderr, "[CLI] ERROR: --lego requires the base DiT model\n");
203-
gf_close(&gf);
204-
dit_ggml_free(&model);
205-
return 1;
206-
}
207190
if (sl_data) {
208191
silence_full.resize(15000 * 64);
209192
memcpy(silence_full.data(), sl_data, 15000 * 64 * sizeof(float));
@@ -301,11 +284,43 @@ int main(int argc, char ** argv) {
301284
fprintf(stderr, "[Request] ERROR: failed to parse %s, skipping\n", rpath);
302285
continue;
303286
}
304-
if (req.caption.empty()) {
287+
if (req.caption.empty() && req.lego.empty()) {
305288
fprintf(stderr, "[Request] ERROR: caption is empty in %s, skipping\n", rpath);
306289
continue;
307290
}
308291

292+
// Lego mode validation (base model only, requires --src-audio)
293+
bool is_lego = !req.lego.empty();
294+
if (is_lego) {
295+
if (!src_audio_path) {
296+
fprintf(stderr, "[Lego] ERROR: lego requires --src-audio\n");
297+
return 1;
298+
}
299+
if (is_turbo) {
300+
fprintf(stderr, "[Lego] ERROR: lego requires the base DiT model (turbo detected)\n");
301+
return 1;
302+
}
303+
// Reference project: TRACK_NAMES (constants.py)
304+
static const char * allowed[] = {
305+
"vocals", "backing_vocals", "drums", "bass", "guitar", "keyboard",
306+
"percussion", "strings", "synth", "fx", "brass", "woodwinds",
307+
};
308+
bool valid = false;
309+
for (int k = 0; k < 12; k++) {
310+
if (req.lego == allowed[k]) {
311+
valid = true;
312+
break;
313+
}
314+
}
315+
if (!valid) {
316+
fprintf(stderr, "[Lego] ERROR: '%s' is not a valid track name\n", req.lego.c_str());
317+
fprintf(stderr,
318+
" Valid: vocals, backing_vocals, drums, bass, guitar, keyboard,\n"
319+
" percussion, strings, synth, fx, brass, woodwinds\n");
320+
return 1;
321+
}
322+
}
323+
309324
// Extract params
310325
const char * caption = req.caption.c_str();
311326
const char * lyrics = req.lyrics.c_str();
@@ -424,32 +439,36 @@ int main(int argc, char ** argv) {
424439
}
425440

426441
// 2. Build formatted prompts
427-
// Reference project uses opposite-sounding instructions (constants.py):
442+
// Reference project instruction templates (constants.py TASK_INSTRUCTIONS):
428443
// text2music = "Fill the audio semantic mask..."
429444
// cover = "Generate audio semantic tokens..."
430445
// repaint = "Repaint the mask area..."
431-
// lego = "Generate the {track} track based on the audio context:"
446+
// lego = "Generate the {TRACK_NAME} track based on the audio context:"
432447
// Auto-switches to cover when audio_codes are present
433-
bool is_cover = have_cover || !codes_vec.empty();
434-
435-
// Lego: build instruction from the track name supplied via --lego <track>
436-
char lego_instruction[256] = {};
437-
const char * instruction;
438-
if (lego_track) {
439-
snprintf(lego_instruction, sizeof(lego_instruction),
440-
"Generate the %s track based on the audio context:", lego_track);
441-
instruction = lego_instruction;
442-
fprintf(stderr, "[Lego] track=%s\n", lego_track);
448+
bool is_cover = have_cover || !codes_vec.empty();
449+
std::string instruction_str;
450+
if (is_lego) {
451+
// Lego mode: force audio_cover_strength=1.0 so all DiT steps see the source audio
452+
req.audio_cover_strength = 1.0f;
453+
fprintf(stderr, "[Lego] track=%s, cover path, strength=1.0\n", req.lego.c_str());
454+
// Reference project (task_utils.py:86): track name is UPPERCASE
455+
std::string track_upper = req.lego;
456+
for (char & c : track_upper) {
457+
c = (char) toupper((unsigned char) c);
458+
}
459+
instruction_str = "Generate the " + track_upper + " track based on the audio context:";
460+
} else if (is_repaint) {
461+
instruction_str = "Repaint the mask area based on the given conditions:";
462+
} else if (is_cover) {
463+
instruction_str = "Generate audio semantic tokens based on the given conditions:";
443464
} else {
444-
instruction = is_repaint ? "Repaint the mask area based on the given conditions:" :
445-
is_cover ? "Generate audio semantic tokens based on the given conditions:" :
446-
"Fill the audio semantic mask based on the given conditions:";
465+
instruction_str = "Fill the audio semantic mask based on the given conditions:";
447466
}
448467

449468
char metas[512];
450469
snprintf(metas, sizeof(metas), "- bpm: %s\n- timesignature: %s\n- keyscale: %s\n- duration: %d seconds\n", bpm,
451470
timesig, keyscale, (int) duration);
452-
std::string text_str = std::string("# Instruction\n") + instruction + "\n\n" + "# Caption\n" + caption +
471+
std::string text_str = std::string("# Instruction\n") + instruction_str + "\n\n" + "# Caption\n" + caption +
453472
"\n\n" + "# Metas\n" + metas + "<|endoftext|>\n";
454473

455474
std::string lyric_str = std::string("# Languages\n") + language + "\n\n# Lyric\n" + lyrics + "<|endoftext|>";
@@ -567,7 +586,7 @@ int main(int argc, char ** argv) {
567586
}
568587

569588
// Build context: [T, ctx_ch] = src_latents[64] + chunk_mask[64]
570-
// Cover: src = cover_latents, mask = 1.0 everywhere
589+
// Cover/Lego: src = cover_latents, mask = 1.0 everywhere
571590
// Repaint: src = silence in region / cover outside, mask = 1.0 in region / 0.0 outside
572591
// Passthrough: detokenized FSQ codes + silence padding, mask = 1.0
573592
// Text2music: silence only, mask = 1.0

0 commit comments

Comments
 (0)