
Commit 7e141b5

model comments, docstrings
1 parent f6bea11 commit 7e141b5

File tree

1 file changed: +30 -9 lines changed


notepredictor/notepredictor/model.py

+30-9
@@ -81,7 +81,19 @@ def forward(self, x):
         # return self.net(x)
 
 class ModalityTransformer(nn.Module):
-    """Model joint distribution of modalities autoregressively with random permutations"""
+    """
+    Model joint distribution of note modalities (e.g. pitch, time, velocity).
+
+    This is an autoregressive Transformer model for the *internal* structure of notes.
+    It is *not* autoregressive in time, but in modality.
+    At training time, it executes in parallel over all timesteps and modalities, with
+    time dependencies provided via the RNN backbone.
+
+    At sampling time it is called serially, one modality at a time,
+    repeatedly at each time step.
+
+    Inspired by XLNet: http://arxiv.org/abs/1906.08237
+    """
     def __init__(self, input_size, hidden_size, heads=4, layers=1):
         super().__init__()
         self.net = nn.TransformerDecoder(
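
To make the new docstring concrete (an illustrative sketch, not code from this commit): with three note modalities and a random ordering, the joint distribution of a single note is factored by the chain rule, each modality conditioned on the previously chosen ones, while the RNN state carries the dependence on earlier notes in time.

import torch

# hypothetical names, for illustration only
modalities = ['pitch', 'time', 'velocity']
perm = torch.randperm(len(modalities))        # e.g. tensor([1, 2, 0])
order = [modalities[i] for i in perm]         # e.g. ['time', 'velocity', 'pitch']

# chain rule under this ordering:
#   p(pitch, time, velocity | h) =
#       p(time | h) * p(velocity | time, h) * p(pitch | time, velocity, h)
# where h is the RNN state summarizing previous notes (the time dependency).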
@@ -95,13 +107,11 @@ def forward(self, ctx, h_ctx, h_tgt):
         ctx: list of Tensor[batch x time x input_size], length note_dim-1
             these are the embedded ground truth values
         h_ctx: Tensor[batch x time x input_size]
-            (need something to attend to when ctx is empty)
+            projection of RNN state (need something to attend to when ctx is empty)
         h_tgt: list of Tensor[batch x time x input_size], length note_dim
-            these are projections of the RNN state
+            these are projections of the RNN state for each target,
+            which the Transformer will map to distribution parameters.
         """
-        # h_tgt = list(h_tgt)
-        # ctx = list(ctx)
-
         # explicitly broadcast
         h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx)
         h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt)
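
The explicit broadcasts at the end of this hunk bring the context and target tensors to a common shape; a standalone illustration of torch.broadcast_tensors with made-up shapes (not part of the diff):

import torch

h_ctx = torch.zeros(8, 1, 16)                      # batch x 1 x input_size
ctx = [torch.zeros(8, 64, 16), torch.zeros(8, 64, 16)]
h_ctx_b, *ctx_b = torch.broadcast_tensors(h_ctx, *ctx)
print(h_ctx_b.shape)                               # torch.Size([8, 64, 16])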
@@ -122,6 +132,7 @@ def forward(self, ctx, h_ctx, h_tgt):
 
         # generate a mask
         # this is both the target and memory mask
+        # masking is such that each target can only depend on "previous" context
         n = len(h_tgt)
         mask = ~tgt.new_ones((n,n), dtype=bool).tril()
 
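
For reference (an illustration, not part of the commit): with n = 3 modalities the mask built above is the negated lower triangle. PyTorch treats True in a boolean attention mask as "may not attend", so target i only sees h_ctx and the first i ground-truth modalities:

import torch

n = 3
mask = ~torch.ones((n, n), dtype=torch.bool).tril()
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
# row i masks the memory [h_ctx, ctx[0], ctx[1]] for target i:
# target 0 attends only to h_ctx, target 2 attends to everything.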
@@ -254,7 +265,7 @@ def embeddings(self):
 
     def forward(self, pitches, times, velocities, validation=False):
         """
-        teacher-forced probabilistic loss and diagnostics for training
+        teacher-forced probabilistic loss and diagnostics for training.
 
         Args:
            pitches: LongTensor[batch, time]
@@ -263,33 +274,41 @@ def forward(self, pitches, times, velocities, validation=False):
         """
         batch_size, batch_len = pitches.shape
 
+        # embed data to input vectors
         pitch_emb = self.pitch_emb(pitches) # batch, time, emb_size
         time_emb = self.time_emb(times) # batch, time, emb_size
         vel_emb = self.vel_emb(velocities) # batch, time, emb_size
 
         embs = (pitch_emb, time_emb, vel_emb)
 
+        # feed to RNN backbone
         x = torch.cat(embs, -1)[:,:-1] # skip last time position
         ## broadcast initial state to batch size
         initial_state = tuple(
             t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
             for t in self.initial_state)
         h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
 
-        # fit all note factorizations at once.
+        # fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
         # TODO: perm each batch item independently?
+        # get a random ordering for note modalities:
         perm = torch.randperm(self.note_dim)
+        # chunk RNN state into Transformer inputs
         hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
         h_ctx = hs[0]
         h_tgt = [hs[i+1] for i in perm]
+        # embed ground truth values for teacher-forcing
         embs = [embs[i][:,1:] for i in perm[:-1]]
+        # run through Transformer to conditional hidden states
         mode_hs = self.xformer(embs, h_ctx, h_tgt)
+        # permute back to canonical order
         mode_hs = [mode_hs[i] for i in perm.argsort()]
 
+        # final projections to raw distribution parameters
         pitch_params, time_params, vel_params = [
             proj(h) for proj,h in zip(self.projections, mode_hs)]
 
-        # get likelihoods
+        # get likelihoods of data for each modality
         pitch_logits = F.log_softmax(pitch_params, -1)
         pitch_targets = pitches[:,1:,None] #batch, time, 1
         pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
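
The perm / perm.argsort() pair used above is a permute-and-invert round trip; a small standalone check (not from the commit):

import torch

perm = torch.randperm(3)                      # random modality order, e.g. tensor([2, 0, 1])
xs = ['pitch', 'time', 'velocity']
shuffled = [xs[i] for i in perm]              # order the Transformer targets were produced in
restored = [shuffled[i] for i in perm.argsort()]
assert restored == xs                         # back in canonical pitch/time/velocity order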
@@ -309,6 +328,8 @@ def forward(self, pitches, times, velocities, validation=False):
             **{'time_'+k:v for k,v in time_result.items()},
             **{'velocity_'+k:v for k,v in vel_result.items()}
         }
+        # this just computes some extra diagnostics which are inconvenient to do in the
+        # training script. should be turned off during training for performance.
         if validation:
             with torch.no_grad():
                 r['time_acc_30ms'] = (
