|
1 | 1 | import os
|
| 2 | +import mir_eval |
| 3 | +import pretty_midi as pm |
2 | 4 | from utils import logger
|
3 | 5 | from btc_model import *
|
4 | 6 | from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths
|
|
12 | 14 |
|
13 | 15 | # hyperparameters
|
14 | 16 | parser = argparse.ArgumentParser()
|
15 |
| -parser.add_argument('--voca', default=False, type=lambda x: (str(x).lower() == 'true')) |
| 17 | +parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true')) |
16 | 18 | parser.add_argument('--audio_dir', type=str, default='./test')
|
17 | 19 | parser.add_argument('--save_dir', type=str, default='./test')
|
18 | 20 | args = parser.parse_args()
|
|
75 | 77 | continue
|
76 | 78 | if prediction[i].item() != prev_chord:
|
77 | 79 | lines.append(
|
78 |
| - '%.6f %.6f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) |
| 80 | + '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) |
79 | 81 | start_time = time_unit * (n_timestep * t + i)
|
80 | 82 | prev_chord = prediction[i].item()
|
81 | 83 | if t == num_instance - 1 and i + num_pad == n_timestep:
|
82 | 84 | if start_time != time_unit * (n_timestep * t + i):
|
83 |
| - lines.append('%.6f %.6f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) |
| 85 | + lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) |
84 | 86 | break
|
85 | 87 |
|
86 | 88 | # lab file write
|
|
93 | 95 |
|
94 | 96 | logger.info("label file saved : %s" % save_path)
|
95 | 97 |
|
| 98 | + # lab file to midi file |
| 99 | + |
| 100 | + |
| 101 | + starts, ends, pitchs = list(), list(), list() |
| 102 | + |
| 103 | + intervals, chords = mir_eval.io.load_labeled_intervals(save_path) |
| 104 | + for p in range(12): |
| 105 | + for i, (interval, chord) in enumerate(zip(intervals, chords)): |
| 106 | + root_num, relative_bitmap, _ = mir_eval.chord.encode(chord) |
| 107 | + tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p] |
| 108 | + if i == 0: |
| 109 | + start_time = interval[0] |
| 110 | + label = tmp_label |
| 111 | + continue |
| 112 | + if tmp_label != label: |
| 113 | + if label == 1.0: |
| 114 | + starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48) |
| 115 | + start_time = interval[0] |
| 116 | + label = tmp_label |
| 117 | + if i == (len(intervals) - 1): |
| 118 | + if label == 1.0: |
| 119 | + starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48) |
| 120 | + |
| 121 | + midi = pm.PrettyMIDI() |
| 122 | + instrument = pm.Instrument(program=0) |
| 123 | + |
| 124 | + for start, end, pitch in zip(starts, ends, pitchs): |
| 125 | + pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end) |
| 126 | + instrument.notes.append(pm_note) |
| 127 | + |
| 128 | + midi.instruments.append(instrument) |
| 129 | + midi.write(save_path.replace('.lab', '.midi')) |
| 130 | + |
0 commit comments