Skip to content

Commit

Permalink
black code formatting + readme update
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed Jan 13, 2023
1 parent 308fb27 commit be0702b
Show file tree
Hide file tree
Showing 23 changed files with 3,353 additions and 1,691 deletions.
123 changes: 56 additions & 67 deletions README.md

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions miditok/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@


def _changed_class_warning(class_obj):
print(f'\033[93mmiditok warning: {class_obj.__class__.__name__} class has been renamed '
f'{class_obj.__class__.__bases__[0].__name__} and will be removed in future updates, '
f'please consider changing it in your code.\033[0m')
print(
f"\033[93mmiditok warning: {class_obj.__class__.__name__} class has been renamed "
f"{class_obj.__class__.__bases__[0].__name__} and will be removed in future updates, "
f"please consider changing it in your code.\033[0m"
)


class REMIEncoding(REMI):
Expand Down
555 changes: 284 additions & 271 deletions miditok/constants.py

Large diffs are not rendered by default.

467 changes: 316 additions & 151 deletions miditok/cp_word.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion miditok/data_augmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .data_augmentation import data_augmentation_dataset, data_augmentation_tokens, data_augmentation_midi
from .data_augmentation import (
data_augmentation_dataset,
data_augmentation_tokens,
data_augmentation_midi,
)

__all__ = [
"data_augmentation_dataset",
Expand Down
267 changes: 180 additions & 87 deletions miditok/data_augmentation/data_augmentation.py

Large diffs are not rendered by default.

328 changes: 225 additions & 103 deletions miditok/midi_like.py

Large diffs are not rendered by default.

639 changes: 441 additions & 198 deletions miditok/midi_tokenizer_base.py

Large diffs are not rendered by default.

425 changes: 293 additions & 132 deletions miditok/mumidi.py

Large diffs are not rendered by default.

329 changes: 231 additions & 98 deletions miditok/octuple.py

Large diffs are not rendered by default.

200 changes: 139 additions & 61 deletions miditok/octuple_mono.py

Large diffs are not rendered by default.

335 changes: 236 additions & 99 deletions miditok/remi.py

Large diffs are not rendered by default.

199 changes: 149 additions & 50 deletions miditok/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@

from .midi_tokenizer_base import MIDITokenizer
from .vocabulary import Vocabulary, Event
from .constants import PITCH_RANGE, NB_VELOCITIES, BEAT_RES, ADDITIONAL_TOKENS, TIME_DIVISION, TEMPO, MIDI_INSTRUMENTS
from .constants import (
PITCH_RANGE,
NB_VELOCITIES,
BEAT_RES,
ADDITIONAL_TOKENS,
TIME_DIVISION,
TEMPO,
MIDI_INSTRUMENTS,
)


class Structured(MIDITokenizer):
Expand Down Expand Up @@ -39,15 +47,33 @@ class Structured(MIDITokenizer):
:param mask: will add a MASK token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""
def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES, additional_tokens: Dict[str, Union[bool, int]] = ADDITIONAL_TOKENS,
pad: bool = True, sos_eos: bool = False, mask: bool = False, params=None):

def __init__(
self,
pitch_range: range = PITCH_RANGE,
beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES,
additional_tokens: Dict[str, Union[bool, int]] = ADDITIONAL_TOKENS,
pad: bool = True,
sos_eos: bool = False,
mask: bool = False,
params=None,
):
# No additional tokens
additional_tokens['Chord'] = False # Incompatible additional token
additional_tokens['Rest'] = False
additional_tokens['Tempo'] = False
additional_tokens['TimeSignature'] = False
super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, pad, sos_eos, mask, params=params)
additional_tokens["Chord"] = False # Incompatible additional token
additional_tokens["Rest"] = False
additional_tokens["Tempo"] = False
additional_tokens["TimeSignature"] = False
super().__init__(
pitch_range,
beat_res,
nb_velocities,
additional_tokens,
pad,
sos_eos,
mask,
params=params,
)

def track_to_tokens(self, track: Instrument) -> List[int]:
r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens
Expand All @@ -59,54 +85,107 @@ def track_to_tokens(self, track: Instrument) -> List[int]:
# notes.sort(key=lambda x: (x.start, x.pitch)) # done in midi_to_tokens
events = []

dur_bins = self.durations_ticks[self.current_midi_metadata['time_division']]
dur_bins = self.durations_ticks[self.current_midi_metadata["time_division"]]

# First time shift if needed
if track.notes[0].start != 0:
if track.notes[0].start > max(dur_bins):
time_shift = track.notes[0].start % self.current_midi_metadata['time_division'] # beat wise
time_shift = (
track.notes[0].start % self.current_midi_metadata["time_division"]
) # beat wise
else:
time_shift = track.notes[0].start
index = np.argmin(np.abs(dur_bins - time_shift))
events.append(Event(type_='TimeShift', value='.'.join(map(str, self.durations[index])), time=0,
desc=f'{time_shift} ticks'))
events.append(
Event(
type_="TimeShift",
value=".".join(map(str, self.durations[index])),
time=0,
desc=f"{time_shift} ticks",
)
)

# Creates the Pitch, Velocity, Duration and Time Shift events
for n, note in enumerate(track.notes[:-1]):
# Pitch
events.append(Event(type_='Pitch', value=note.pitch, time=note.start, desc=note.pitch))
events.append(
Event(type_="Pitch", value=note.pitch, time=note.start, desc=note.pitch)
)
# Velocity
events.append(Event(type_='Velocity', value=note.velocity, time=note.start, desc=f'{note.velocity}'))
events.append(
Event(
type_="Velocity",
value=note.velocity,
time=note.start,
desc=f"{note.velocity}",
)
)
# Duration
duration = note.end - note.start
index = np.argmin(np.abs(dur_bins - duration))
events.append(Event(type_='Duration', value='.'.join(map(str, self.durations[index])), time=note.start,
desc=f'{duration} ticks'))
events.append(
Event(
type_="Duration",
value=".".join(map(str, self.durations[index])),
time=note.start,
desc=f"{duration} ticks",
)
)
# TimeShift
time_shift = track.notes[n + 1].start - note.start
index = np.argmin(np.abs(dur_bins - time_shift))
events.append(Event(type_='TimeShift', time=note.start, desc=f'{time_shift} ticks',
value='.'.join(map(str, self.durations[index])) if time_shift != 0 else '0.0.1'))
events.append(
Event(
type_="TimeShift",
time=note.start,
desc=f"{time_shift} ticks",
value=".".join(map(str, self.durations[index]))
if time_shift != 0
else "0.0.1",
)
)
# Adds the last note
if track.notes[-1].pitch not in self.pitch_range:
if len(events) > 0:
del events[-1]
else:
events.append(Event(type_='Pitch', value=track.notes[-1].pitch, time=track.notes[-1].start,
desc=track.notes[-1].pitch))
events.append(Event(type_='Velocity', value=track.notes[-1].velocity, time=track.notes[-1].start,
desc=f'{track.notes[-1].velocity}'))
events.append(
Event(
type_="Pitch",
value=track.notes[-1].pitch,
time=track.notes[-1].start,
desc=track.notes[-1].pitch,
)
)
events.append(
Event(
type_="Velocity",
value=track.notes[-1].velocity,
time=track.notes[-1].start,
desc=f"{track.notes[-1].velocity}",
)
)
duration = track.notes[-1].end - track.notes[-1].start
index = np.argmin(np.abs(dur_bins - duration))
events.append(Event(type_='Duration', value='.'.join(map(str, self.durations[index])),
time=track.notes[-1].start, desc=f'{duration} ticks'))
events.append(
Event(
type_="Duration",
value=".".join(map(str, self.durations[index])),
time=track.notes[-1].start,
desc=f"{duration} ticks",
)
)

events.sort(key=lambda x: x.time)

return self.events_to_tokens(events)

def tokens_to_track(self, tokens: List[int], time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False)) -> Tuple[Instrument, List[TempoChange]]:
def tokens_to_track(
self,
tokens: List[int],
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
r"""Converts a sequence of tokens into a track object
:param tokens: sequence of tokens to convert
Expand All @@ -116,24 +195,31 @@ def tokens_to_track(self, tokens: List[int], time_division: Optional[int] = TIME
"""
events = self.tokens_to_events(tokens)

name = 'Drums' if program[1] else MIDI_INSTRUMENTS[program[0]]['name']
name = "Drums" if program[1] else MIDI_INSTRUMENTS[program[0]]["name"]
instrument = Instrument(program[0], is_drum=program[1], name=name)
current_tick = 0
count = 0

while count < len(events):
if events[count].type == 'Pitch':
if count + 2 < len(events) and events[count + 1].type == 'Velocity' \
and events[count + 2].type == 'Duration':
if events[count].type == "Pitch":
if (
count + 2 < len(events)
and events[count + 1].type == "Velocity"
and events[count + 2].type == "Duration"
):
pitch = int(events[count].value)
vel = int(events[count + 1].value)
duration = self._token_duration_to_ticks(events[count + 2].value, time_division)
instrument.notes.append(Note(vel, pitch, current_tick, current_tick + duration))
duration = self._token_duration_to_ticks(
events[count + 2].value, time_division
)
instrument.notes.append(
Note(vel, pitch, current_tick, current_tick + duration)
)
count += 3
else:
count += 1
elif events[count].type == 'TimeShift':
beat, pos, res = map(int, events[count].value.split('.'))
elif events[count].type == "TimeShift":
beat, pos, res = map(int, events[count].value.split("."))
current_tick += (beat * res + pos) * time_division // res # time shift
count += 1
else:
Expand All @@ -150,26 +236,32 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> Vocabulary:
:return: the vocabulary object
"""
if sos_eos_tokens is not None:
print('\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, '
'_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m')
print(
"\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, "
"_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m"
)
vocab = Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask)

# PITCH
vocab.add_event(f'Pitch_{i}' for i in self.pitch_range)
vocab.add_event(f"Pitch_{i}" for i in self.pitch_range)

# VELOCITY
vocab.add_event(f'Velocity_{i}' for i in self.velocities)
vocab.add_event(f"Velocity_{i}" for i in self.velocities)

# DURATION
vocab.add_event(f'Duration_{".".join(map(str, duration))}' for duration in self.durations)
vocab.add_event(
f'Duration_{".".join(map(str, duration))}' for duration in self.durations
)

# TIME SHIFT (same as durations)
vocab.add_event('TimeShift_0.0.1') # for a time shift of 0
vocab.add_event(f'TimeShift_{".".join(map(str, duration))}' for duration in self.durations)
vocab.add_event("TimeShift_0.0.1") # for a time shift of 0
vocab.add_event(
f'TimeShift_{".".join(map(str, duration))}' for duration in self.durations
)

# PROGRAM
if self.additional_tokens['Program']:
vocab.add_event(f'Program_{program}' for program in range(-1, 128))
if self.additional_tokens["Program"]:
vocab.add_event(f"Program_{program}" for program in range(-1, 128))

return vocab

Expand All @@ -181,11 +273,18 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
:return: the token types transitions dictionary
"""
dic = {'Pitch': ['Velocity'], 'Velocity': ['Duration'], 'Duration': ['TimeShift'], 'TimeShift': ['Pitch']}
dic = {
"Pitch": ["Velocity"],
"Velocity": ["Duration"],
"Duration": ["TimeShift"],
"TimeShift": ["Pitch"],
}
self._add_special_tokens_to_types_graph(dic)
return dic

def token_types_errors(self, tokens: List[int], consider_pad: bool = False) -> float:
def token_types_errors(
self, tokens: List[int], consider_pad: bool = False
) -> float:
r"""Checks if a sequence of tokens is constituted of good token types
successions and returns the error ratio (lower is better).
The Pitch values are also analyzed:
Expand All @@ -206,16 +305,16 @@ def token_types_errors(self, tokens: List[int], consider_pad: bool = False) -> f

def check(tok: int):
nonlocal err, previous_type, current_pitches
token_type, token_value = self.vocab.token_to_event[tok].split('_')
token_type, token_value = self.vocab.token_to_event[tok].split("_")

# Good token type
if token_type in self.tokens_types_graph[previous_type]:
if token_type == 'Pitch':
if token_type == "Pitch":
if int(token_value) in current_pitches:
err += 1 # pitch already played at current position
else:
current_pitches.append(int(token_value))
elif token_type == 'TimeShift':
elif token_type == "TimeShift":
if self._token_duration_to_ticks(token_value, 48) > 0:
current_pitches = [] # moving in time, list reset
# Bad token type
Expand All @@ -228,7 +327,7 @@ def check(tok: int):
check(token)
else:
for token in tokens[1:]:
if previous_type == 'PAD':
if previous_type == "PAD":
break
check(token)
return err / nb_tok_predicted
Loading

0 comments on commit be0702b

Please sign in to comment.