Skip to content

Commit

Permalink
data augmentation all_offset_combinations argument + fix sep argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathan Fradet committed Jan 26, 2023
1 parent f6225a1 commit bb24512
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 35 deletions.
7 changes: 6 additions & 1 deletion miditok/cp_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class CPWord(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary. (default: None)
"""

Expand All @@ -66,6 +67,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
):
# Indexes of additional token types within a compound token
Expand All @@ -88,6 +90,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
params=params,
)

Expand Down Expand Up @@ -441,7 +444,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:
)

vocab = [
Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)
for _ in range(5)
]

Expand Down
60 changes: 46 additions & 14 deletions miditok/data_augmentation/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def data_augmentation_dataset(
octave_directions: Tuple[bool, bool] = (True, True),
vel_directions: Tuple[bool, bool] = (True, True),
dur_directions: Tuple[bool, bool] = (True, True),
all_offset_combinations: bool = False,
out_path: Union[Path, str] = None,
copy_original_in_new_location: bool = True,
):
Expand All @@ -43,6 +44,9 @@ def data_augmentation_dataset(
as a tuple of two booleans. (default: (True, True))
:param dur_directions: directions to shift the duration augmentation, for up / down
as a tuple of two booleans. (default: (True, True))
:param all_offset_combinations: will perform data augmentation on all the possible
combinations of offsets. If set to False, will perform data augmentation
only based on the original sample.
:param out_path: output path to save the augmented files. Original (non-augmented) MIDIs will be
saved to this location. If none is given, they will be saved in the same location an the
data_path. (default: None)
Expand Down Expand Up @@ -89,7 +93,12 @@ def data_augmentation_dataset(
for track, (_, is_drum) in zip(tokens, programs):
if is_drum: # we dont augment drums
continue
aug = data_augmentation_tokens(np.array(track), tokenizer, *offsets)
aug = data_augmentation_tokens(
np.array(track),
tokenizer,
*offsets,
all_offset_combinations=all_offset_combinations,
)
if len(aug) == 0:
continue
for aug_offsets, seq in aug:
Expand Down Expand Up @@ -145,7 +154,12 @@ def data_augmentation_dataset(
dur_directions,
midi=midi,
)
augmented_midis = data_augmentation_midi(midi, tokenizer, *offsets)
augmented_midis = data_augmentation_midi(
midi,
tokenizer,
*offsets,
all_offset_combinations=all_offset_combinations,
)
for aug_offsets, aug_midi in augmented_midis:
if len(aug_midi.instruments) == 0:
continue
Expand Down Expand Up @@ -292,6 +306,7 @@ def data_augmentation_midi(
pitch_offsets: List[int] = None,
velocity_offsets: List[int] = None,
duration_offsets: List[int] = None,
all_offset_combinations: bool = False,
) -> List[Tuple[Tuple[int, int, int], MidiFile]]:
r"""Perform data augmentation on a MIDI object.
Drum tracks are not augmented, but copied as original in augmented MIDIs.
Expand All @@ -301,6 +316,9 @@ def data_augmentation_midi(
:param pitch_offsets: list of pitch offsets for augmentation.
:param velocity_offsets: list of velocity offsets for augmentation.
:param duration_offsets: list of duration offsets for augmentation.
:param all_offset_combinations: will perform data augmentation on all the possible
combinations of offsets. If set to False, will perform data augmentation
only based on the original sample.
:return: augmented MIDI objects.
"""
augmented = []
Expand Down Expand Up @@ -337,9 +355,12 @@ def augment_vel(
aug_.append(((offsets_[0], offset_, offsets_[2]), midi_aug_))
return aug_

for i in range(len(augmented)):
offsets, midi_aug = augmented[i]
augmented += augment_vel(midi_aug, offsets) # for already augmented midis
if all_offset_combinations:
for i in range(len(augmented)):
offsets, midi_aug = augmented[i]
augmented += augment_vel(
midi_aug, offsets
) # for already augmented midis
augmented += augment_vel(midi, (0, 0, 0)) # for original midi

# TODO Duration augmentation
Expand All @@ -361,9 +382,10 @@ def augment_dur(midi_: MidiFile, offsets_: Tuple[int, int, int]) -> List[Tuple[T
aug_.append(((offsets_[0], offsets_[1], offset_), midi_aug_))
return aug_
for i in range(len(augmented)):
offsets, midi_aug = augmented[i]
augmented += augment_dur(midi_aug, offsets) # for already augmented midis
if all_offset_combinations:
for i in range(len(augmented)):
offsets, midi_aug = augmented[i]
augmented += augment_dur(midi_aug, offsets) # for already augmented midis
augmented += augment_dur(midi, (0, 0, 0)) # for original midi"""

return augmented
Expand All @@ -375,6 +397,7 @@ def data_augmentation_tokens(
pitch_offsets: List[int] = None,
velocity_offsets: List[int] = None,
duration_offsets: List[int] = None,
all_offset_combinations: bool = False,
) -> List[Tuple[Tuple[int, int, int], List[int]]]:
r"""Perform data augmentation on a sequence of tokens, on the pitch dimension.
NOTE: token sequences with BPE will be decoded during the augmentation, this might take some time.
Expand All @@ -389,6 +412,9 @@ def data_augmentation_tokens(
:param pitch_offsets: list of pitch offsets for augmentation.
:param velocity_offsets: list of velocity offsets for augmentation.
:param duration_offsets: list of duration offsets for augmentation.
:param all_offset_combinations: will perform data augmentation on all the possible
combinations of offsets. If set to False, will perform data augmentation
only based on the original sample.
:return: the several data augmentations that have been performed
"""
augmented = []
Expand Down Expand Up @@ -464,9 +490,12 @@ def augment_vel(
aug_.append(((offsets_[0], offset_, offsets_[2]), aug_seq))
return aug_

for i in range(len(augmented)):
offsets, seq_aug = augmented[i]
augmented += augment_vel(seq_aug, offsets) # for already augmented midis
if all_offset_combinations:
for i in range(len(augmented)):
offsets, seq_aug = augmented[i]
augmented += augment_vel(
seq_aug, offsets
) # for already augmented midis
augmented += augment_vel(tokens, (0, 0, 0)) # for original midi

# Duration augmentation
Expand Down Expand Up @@ -502,9 +531,12 @@ def augment_dur(
aug_.append(((offsets_[0], offsets_[1], offset_), aug_seq))
return aug_

for i in range(len(augmented)):
offsets, seq_aug = augmented[i]
augmented += augment_dur(seq_aug, offsets) # for already augmented midis
if all_offset_combinations:
for i in range(len(augmented)):
offsets, seq_aug = augmented[i]
augmented += augment_dur(
seq_aug, offsets
) # for already augmented midis
augmented += augment_dur(tokens, (0, 0, 0)) # for original midi

# Convert all arrays to lists and reapply BPE if necessary
Expand Down
7 changes: 6 additions & 1 deletion miditok/midi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MIDILike(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

Expand All @@ -59,6 +60,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
):
additional_tokens["TimeSignature"] = False # not compatible
Expand All @@ -70,6 +72,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
params=params,
)

Expand Down Expand Up @@ -321,7 +324,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> Vocabulary:
"\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, "
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
vocab = Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)

# NOTE ON
vocab.add_event(f"NoteOn_{i}" for i in self.pitch_range)
Expand Down
19 changes: 15 additions & 4 deletions miditok/mumidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class MuMIDI(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
:param drum_pitch_range: range of used MIDI pitches for drums exclusively
"""
Expand All @@ -60,6 +61,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
drum_pitch_range: range = DRUM_PITCH_RANGE,
):
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
True,
params=params,
)
Expand Down Expand Up @@ -204,14 +207,20 @@ def midi_to_tokens(self, midi: MidiFile, *args, **kwargs) -> List[List[int]]:
current_pos = -1
current_track = -2 # because -2 doesn't exist
current_tempo_idx = 0
current_tempo = self.current_midi_metadata["tempo_changes"][current_tempo_idx].tempo
current_tempo = self.current_midi_metadata["tempo_changes"][
current_tempo_idx
].tempo
for note_token in note_tokens:
# (Tempo) update tempo values current_tempo
if self.additional_tokens["Tempo"]:
# If the current tempo is not the last one
if current_tempo_idx + 1 < len(self.current_midi_metadata["tempo_changes"]):
if current_tempo_idx + 1 < len(
self.current_midi_metadata["tempo_changes"]
):
# Will loop over incoming tempo changes
for tempo_change in self.current_midi_metadata["tempo_changes"][current_tempo_idx + 1:]:
for tempo_change in self.current_midi_metadata["tempo_changes"][
current_tempo_idx + 1:
]:
# If this tempo change happened before the current moment
if tempo_change.time <= note_token[0].time:
current_tempo = tempo_change.tempo
Expand Down Expand Up @@ -489,7 +498,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = [
Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)
for _ in range(3)
]

Expand Down
11 changes: 9 additions & 2 deletions miditok/octuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Octuple(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

Expand All @@ -53,6 +54,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
):
additional_tokens["Chord"] = False # Incompatible additional token
Expand All @@ -76,6 +78,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
True,
params=params,
)
Expand Down Expand Up @@ -253,7 +256,9 @@ def track_to_tokens(self, track: Instrument) -> List[List[Union[Event, int]]]:
self.current_midi_metadata["tempo_changes"]
):
# Will loop over incoming tempo changes
for tempo_change in self.current_midi_metadata["tempo_changes"][current_tempo_idx + 1:]:
for tempo_change in self.current_midi_metadata["tempo_changes"][
current_tempo_idx + 1:
]:
# If this tempo change happened before the current moment
if tempo_change.time <= note.start:
current_tempo = tempo_change.tempo
Expand Down Expand Up @@ -460,7 +465,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = [
Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)
for _ in range(6)
]

Expand Down
7 changes: 6 additions & 1 deletion miditok/octuple_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class OctupleMono(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

Expand All @@ -51,6 +52,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
):
additional_tokens["Chord"] = False # Incompatible additional token
Expand All @@ -74,6 +76,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
params=params,
)

Expand Down Expand Up @@ -267,7 +270,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = [
Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)
for _ in range(5)
]

Expand Down
7 changes: 6 additions & 1 deletion miditok/remi.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class REMI(MIDITokenizer):
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param sep: will add a SEP token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

Expand All @@ -51,6 +52,7 @@ def __init__(
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
sep: bool = False,
params=None,
):
additional_tokens["TimeSignature"] = False # not compatible
Expand All @@ -62,6 +64,7 @@ def __init__(
pad,
sos_eos,
mask,
sep,
params=params,
)

Expand Down Expand Up @@ -327,7 +330,9 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> Vocabulary:
"\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, "
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep)
vocab = Vocabulary(
pad=self._pad, sos_eos=self._sos_eos, mask=self._mask, sep=self._sep
)

# BAR
vocab.add_event("Bar_None")
Expand Down
Loading

0 comments on commit bb24512

Please sign in to comment.