1616#include " vae-enc.h"
1717#include " vae.h"
1818
19+ #include < cctype>
1920#include < cstdio>
2021#include < cstdlib>
2122#include < cstring>
@@ -32,11 +33,6 @@ static void print_usage(const char * prog) {
3233 " --vae <gguf> VAE GGUF file\n\n "
3334 " Reference audio:\n "
3435 " --src-audio <file> Source audio (WAV or MP3, any sample rate)\n\n "
35- " Lego mode (base model only, requires --src-audio):\n "
36- " --lego <track> Generate a track over the source audio context\n "
37- " Track names: vocals, backing_vocals, drums, bass,\n "
38- " guitar, keyboard, percussion, strings, synth,\n "
39- " fx, brass, woodwinds\n\n "
4036 " LoRA:\n "
4137 " --lora <path> LoRA safetensors file or directory\n "
4238 " --lora-scale <float> LoRA scaling factor (default: 1.0)\n\n "
@@ -88,7 +84,6 @@ int main(int argc, char ** argv) {
8884 const char * dit_gguf = NULL ;
8985 const char * vae_gguf = NULL ;
9086 const char * src_audio_path = NULL ;
91- const char * lego_track = NULL ; // --lego <track>
9287 const char * dump_dir = NULL ;
9388 const char * lora_path = NULL ;
9489 float lora_scale = 1 .0f ;
@@ -113,8 +108,6 @@ int main(int argc, char ** argv) {
113108 vae_gguf = argv[++i];
114109 } else if (strcmp (argv[i], " --src-audio" ) == 0 && i + 1 < argc) {
115110 src_audio_path = argv[++i];
116- } else if (strcmp (argv[i], " --lego" ) == 0 && i + 1 < argc) {
117- lego_track = argv[++i];
118111 } else if (strcmp (argv[i], " --lora" ) == 0 && i + 1 < argc) {
119112 lora_path = argv[++i];
120113 } else if (strcmp (argv[i], " --lora-scale" ) == 0 && i + 1 < argc) {
@@ -152,10 +145,6 @@ int main(int argc, char ** argv) {
152145 fprintf (stderr, " [CLI] ERROR: --batch must be 1..9\n " );
153146 return 1 ;
154147 }
155- if (lego_track && !src_audio_path) {
156- fprintf (stderr, " [CLI] ERROR: --lego requires --src-audio\n " );
157- return 1 ;
158- }
159148 if (!dit_gguf) {
160149 fprintf (stderr, " [CLI] ERROR: --dit required\n " );
161150 print_usage (argv[0 ]);
@@ -198,12 +187,6 @@ int main(int argc, char ** argv) {
198187 if (gf_load (&gf, dit_gguf)) {
199188 is_turbo = gf_get_bool (gf, " acestep.is_turbo" );
200189 const void * sl_data = gf_get_data (gf, " silence_latent" );
201- if (lego_track && is_turbo) {
202- fprintf (stderr, " [CLI] ERROR: --lego requires the base DiT model\n " );
203- gf_close (&gf);
204- dit_ggml_free (&model);
205- return 1 ;
206- }
207190 if (sl_data) {
208191 silence_full.resize (15000 * 64 );
209192 memcpy (silence_full.data (), sl_data, 15000 * 64 * sizeof (float ));
@@ -301,11 +284,43 @@ int main(int argc, char ** argv) {
301284 fprintf (stderr, " [Request] ERROR: failed to parse %s, skipping\n " , rpath);
302285 continue ;
303286 }
304- if (req.caption .empty ()) {
287+ if (req.caption .empty () && req. lego . empty () ) {
305288 fprintf (stderr, " [Request] ERROR: caption is empty in %s, skipping\n " , rpath);
306289 continue ;
307290 }
308291
292+ // Lego mode validation (base model only, requires --src-audio)
293+ bool is_lego = !req.lego .empty ();
294+ if (is_lego) {
295+ if (!src_audio_path) {
296+ fprintf (stderr, " [Lego] ERROR: lego requires --src-audio\n " );
297+ return 1 ;
298+ }
299+ if (is_turbo) {
300+ fprintf (stderr, " [Lego] ERROR: lego requires the base DiT model (turbo detected)\n " );
301+ return 1 ;
302+ }
303+ // Reference project: TRACK_NAMES (constants.py)
304+ static const char * allowed[] = {
305+ " vocals" , " backing_vocals" , " drums" , " bass" , " guitar" , " keyboard" ,
306+ " percussion" , " strings" , " synth" , " fx" , " brass" , " woodwinds" ,
307+ };
308+ bool valid = false ;
309+ for (int k = 0 ; k < 12 ; k++) {
310+ if (req.lego == allowed[k]) {
311+ valid = true ;
312+ break ;
313+ }
314+ }
315+ if (!valid) {
316+ fprintf (stderr, " [Lego] ERROR: '%s' is not a valid track name\n " , req.lego .c_str ());
317+ fprintf (stderr,
318+ " Valid: vocals, backing_vocals, drums, bass, guitar, keyboard,\n "
319+ " percussion, strings, synth, fx, brass, woodwinds\n " );
320+ return 1 ;
321+ }
322+ }
323+
309324 // Extract params
310325 const char * caption = req.caption .c_str ();
311326 const char * lyrics = req.lyrics .c_str ();
@@ -424,32 +439,36 @@ int main(int argc, char ** argv) {
424439 }
425440
426441 // 2. Build formatted prompts
427- // Reference project uses opposite-sounding instructions (constants.py):
442+ // Reference project instruction templates (constants.py TASK_INSTRUCTIONS ):
428443 // text2music = "Fill the audio semantic mask..."
429444 // cover = "Generate audio semantic tokens..."
430445 // repaint = "Repaint the mask area..."
431- // lego = "Generate the {track } track based on the audio context:"
446+ // lego = "Generate the {TRACK_NAME } track based on the audio context:"
432447 // Auto-switches to cover when audio_codes are present
433- bool is_cover = have_cover || !codes_vec.empty ();
434-
435- // Lego: build instruction from the track name supplied via --lego <track>
436- char lego_instruction[256 ] = {};
437- const char * instruction;
438- if (lego_track) {
439- snprintf (lego_instruction, sizeof (lego_instruction),
440- " Generate the %s track based on the audio context:" , lego_track);
441- instruction = lego_instruction;
442- fprintf (stderr, " [Lego] track=%s\n " , lego_track);
448+ bool is_cover = have_cover || !codes_vec.empty ();
449+ std::string instruction_str;
450+ if (is_lego) {
451+ // Lego mode: force audio_cover_strength=1.0 so all DiT steps see the source audio
452+ req.audio_cover_strength = 1 .0f ;
453+ fprintf (stderr, " [Lego] track=%s, cover path, strength=1.0\n " , req.lego .c_str ());
454+ // Reference project (task_utils.py:86): track name is UPPERCASE
455+ std::string track_upper = req.lego ;
456+ for (char & c : track_upper) {
457+ c = (char ) toupper ((unsigned char ) c);
458+ }
459+ instruction_str = " Generate the " + track_upper + " track based on the audio context:" ;
460+ } else if (is_repaint) {
461+ instruction_str = " Repaint the mask area based on the given conditions:" ;
462+ } else if (is_cover) {
463+ instruction_str = " Generate audio semantic tokens based on the given conditions:" ;
443464 } else {
444- instruction = is_repaint ? " Repaint the mask area based on the given conditions:" :
445- is_cover ? " Generate audio semantic tokens based on the given conditions:" :
446- " Fill the audio semantic mask based on the given conditions:" ;
465+ instruction_str = " Fill the audio semantic mask based on the given conditions:" ;
447466 }
448467
449468 char metas[512 ];
450469 snprintf (metas, sizeof (metas), " - bpm: %s\n - timesignature: %s\n - keyscale: %s\n - duration: %d seconds\n " , bpm,
451470 timesig, keyscale, (int ) duration);
452- std::string text_str = std::string (" # Instruction\n " ) + instruction + " \n\n " + " # Caption\n " + caption +
471+ std::string text_str = std::string (" # Instruction\n " ) + instruction_str + " \n\n " + " # Caption\n " + caption +
453472 " \n\n " + " # Metas\n " + metas + " <|endoftext|>\n " ;
454473
455474 std::string lyric_str = std::string (" # Languages\n " ) + language + " \n\n # Lyric\n " + lyrics + " <|endoftext|>" ;
@@ -567,7 +586,7 @@ int main(int argc, char ** argv) {
567586 }
568587
569588 // Build context: [T, ctx_ch] = src_latents[64] + chunk_mask[64]
570- // Cover: src = cover_latents, mask = 1.0 everywhere
589+ // Cover/Lego: src = cover_latents, mask = 1.0 everywhere
571590 // Repaint: src = silence in region / cover outside, mask = 1.0 in region / 0.0 outside
572591 // Passthrough: detokenized FSQ codes + silence padding, mask = 1.0
573592 // Text2music: silence only, mask = 1.0
0 commit comments