
Commit 2682317

Commit message: includes converting lab file to midi file
Parent: 8b8ae5f

File tree: 3 files changed (+40, -4 lines)

README.md (+2, -1)
@@ -12,6 +12,7 @@ This repository has the source codes for the paper "A Bi-Directional Transformer
 - librosa >= 0.6.3
 - pyyaml >= 3.13
 - mir_eval >= 0.5
+- pretty_midi >= 0.2.8

 ## File descriptions
 * `audio_dataset.py` : loads data and preprocesses label files to chord labels and mp3 files to constant-q transformation.
@@ -33,7 +34,7 @@ $ python test.py --audio_dir audio_folder --save_dir save_folder --voca False
 * save_dir : a forder for saving recognition results (default: './test')
 * voca : False means major and minor label type, and True means large vocabulary label type (default: False)

-The resulting files are lab files of the form shown below.
+The resulting files are lab files of the form shown below and midi files.

 <img src="png/example.png">
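For orientation, a .lab file is a plain-text chord annotation with one `start end label` row per segment, and after this commit a MIDI file with the same basename is written alongside it. The rows below are invented purely to illustrate the format, not actual output of the model:

```
0.000 2.612 N
2.612 5.224 C:maj
5.224 10.448 G:maj
```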

png/example.png (3.02 KB)

test.py (+38, -3)
@@ -1,4 +1,6 @@
 import os
+import mir_eval
+import pretty_midi as pm
 from utils import logger
 from btc_model import *
 from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths
@@ -12,7 +14,7 @@

 # hyperparameters
 parser = argparse.ArgumentParser()
-parser.add_argument('--voca', default=False, type=lambda x: (str(x).lower() == 'true'))
+parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('--audio_dir', type=str, default='./test')
 parser.add_argument('--save_dir', type=str, default='./test')
 args = parser.parse_args()
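Note on the changed default: the `--voca` flag is parsed with a string comparison, so only a case-insensitive "true" yields True and anything else (including "1" or "yes") becomes False. A minimal standalone sketch of that parsing behaviour, using the same lambda as test.py:

```python
import argparse

# Same idiom as test.py: only the literal string "true" (any casing) maps to True.
parser = argparse.ArgumentParser()
parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true'))

print(parser.parse_args(['--voca', 'False']).voca)  # False
print(parser.parse_args(['--voca', 'true']).voca)   # True
print(parser.parse_args([]).voca)                   # True (the new default in this commit)
```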
@@ -75,12 +77,12 @@
                     continue
                 if prediction[i].item() != prev_chord:
                     lines.append(
-                        '%.6f %.6f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                        '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                     start_time = time_unit * (n_timestep * t + i)
                     prev_chord = prediction[i].item()
                 if t == num_instance - 1 and i + num_pad == n_timestep:
                     if start_time != time_unit * (n_timestep * t + i):
-                        lines.append('%.6f %.6f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                        lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                     break

     # lab file write
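The only functional change in this hunk is the timestamp precision of the written lab lines (%.6f to %.3f). A quick illustration with made-up boundary times:

```python
# Illustrative values only, not actual model output.
start_time, end_time, chord = 2.6122448979, 5.2244897959, 'C:maj'

print('%.6f %.6f %s' % (start_time, end_time, chord))  # 2.612245 5.224490 C:maj
print('%.3f %.3f %s' % (start_time, end_time, chord))  # 2.612 5.224 C:maj
```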
@@ -93,3 +95,36 @@

     logger.info("label file saved : %s" % save_path)

+    # lab file to midi file
+
+
+    starts, ends, pitchs = list(), list(), list()
+
+    intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
+    for p in range(12):
+        for i, (interval, chord) in enumerate(zip(intervals, chords)):
+            root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
+            tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
+            if i == 0:
+                start_time = interval[0]
+                label = tmp_label
+                continue
+            if tmp_label != label:
+                if label == 1.0:
+                    starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48)
+                start_time = interval[0]
+                label = tmp_label
+            if i == (len(intervals) - 1):
+                if label == 1.0:
+                    starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48)
+
+    midi = pm.PrettyMIDI()
+    instrument = pm.Instrument(program=0)
+
+    for start, end, pitch in zip(starts, ends, pitchs):
+        pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
+        instrument.notes.append(pm_note)
+
+    midi.instruments.append(instrument)
+    midi.write(save_path.replace('.lab', '.midi'))
+
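In plain terms, the new block re-reads the .lab file it just wrote and, for each of the 12 pitch classes, emits one sustained MIDI note (pitch class + 48, i.e. the octave starting at C3) for every contiguous run of chords that contain that pitch class; the notes go into a single program-0 instrument at velocity 120. A small standalone check of the two mir_eval calls it relies on, where the G:maj label is only an example and not repository output:

```python
import mir_eval

# Encode a chord label, then rotate the root-relative bitmap to absolute pitch classes.
root_num, relative_bitmap, _ = mir_eval.chord.encode('G:maj')
absolute_bitmap = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)

# Active pitch classes for G major are D (2), G (7) and B (11).
active = [p for p in range(12) if absolute_bitmap[p] == 1]
print(active)                    # [2, 7, 11]
print([p + 48 for p in active])  # [50, 55, 59] -> D3, G3, B3
```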
