@@ -81,7 +81,19 @@ def forward(self, x):
81
81
# return self.net(x)
82
82
83
83
class ModalityTransformer (nn .Module ):
84
- """Model joint distribution of modalities autoregressively with random permutations"""
84
+ """
85
+ Model joint distribution of note modalities (e.g. pitch, time, velocity).
86
+
87
+ This is an autoregressive Transformer model for the *internal* structure of notes.
88
+ It is *not* autoregressive in time, but in modality.
89
+ At training time, it executes in parallel over all timesteps and modalities, with
90
+ time dependencies provided via the RNN backbone.
91
+
92
+ At sampling time it is called serially, one modality at a time,
93
+ repeatedly at each time step.
94
+
95
+ Inspired by XLNet: http://arxiv.org/abs/1906.08237
96
+ """
85
97
def __init__ (self , input_size , hidden_size , heads = 4 , layers = 1 ):
86
98
super ().__init__ ()
87
99
self .net = nn .TransformerDecoder (
@@ -95,13 +107,11 @@ def forward(self, ctx, h_ctx, h_tgt):
95
107
ctx: list of Tensor[batch x time x input_size], length note_dim-1
96
108
these are the embedded ground truth values
97
109
h_ctx: Tensor[batch x time x input_size]
98
- (need something to attend to when ctx is empty)
110
+ projection of RNN state (need something to attend to when ctx is empty)
99
111
h_tgt: list of Tensor[batch x time x input_size], length note_dim
100
- these are projections of the RNN state
112
+ these are projections of the RNN state for each target,
113
+ which the Transformer will map to per-target hidden states (projected to distribution parameters downstream).
101
114
"""
102
- # h_tgt = list(h_tgt)
103
- # ctx = list(ctx)
104
-
105
115
# explicitly broadcast
106
116
h_ctx , * ctx = torch .broadcast_tensors (h_ctx , * ctx )
107
117
h_ctx , * h_tgt = torch .broadcast_tensors (h_ctx , * h_tgt )
@@ -122,6 +132,7 @@ def forward(self, ctx, h_ctx, h_tgt):
122
132
123
133
# generate a mask
124
134
# this is both the target and memory mask
135
+ # masking is such that each target can only depend on "previous" context
125
136
n = len (h_tgt )
126
137
mask = ~ tgt .new_ones ((n ,n ), dtype = bool ).tril ()
127
138
@@ -254,7 +265,7 @@ def embeddings(self):
254
265
255
266
def forward (self , pitches , times , velocities , validation = False ):
256
267
"""
257
- teacher-forced probabilistic loss and diagnostics for training
268
+ teacher-forced probabilistic loss and diagnostics for training.
258
269
259
270
Args:
260
271
pitches: LongTensor[batch, time]
@@ -263,33 +274,41 @@ def forward(self, pitches, times, velocities, validation=False):
263
274
"""
264
275
batch_size , batch_len = pitches .shape
265
276
277
+ # embed data to input vectors
266
278
pitch_emb = self .pitch_emb (pitches ) # batch, time, emb_size
267
279
time_emb = self .time_emb (times ) # batch, time, emb_size
268
280
vel_emb = self .vel_emb (velocities ) # batch, time, emb_size
269
281
270
282
embs = (pitch_emb , time_emb , vel_emb )
271
283
284
+ # feed to RNN backbone
272
285
x = torch .cat (embs , - 1 )[:,:- 1 ] # skip last time position
273
286
## broadcast initial state to batch size
274
287
initial_state = tuple (
275
288
t .expand (self .rnn .num_layers , x .shape [0 ], - 1 ).contiguous () # 1 x batch x hidden
276
289
for t in self .initial_state )
277
290
h , _ = self .rnn (x , initial_state ) #batch, time, hidden_size
278
291
279
- # fit all note factorizations at once.
292
+ # fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
280
293
# TODO: draw an independent permutation for each batch item?
294
+ # get a random ordering for note modalities:
281
295
perm = torch .randperm (self .note_dim )
296
+ # chunk RNN state into Transformer inputs
282
297
hs = list (self .h_proj (h ).chunk (self .note_dim + 1 , - 1 ))
283
298
h_ctx = hs [0 ]
284
299
h_tgt = [hs [i + 1 ] for i in perm ]
300
+ # embed ground truth values for teacher-forcing
285
301
embs = [embs [i ][:,1 :] for i in perm [:- 1 ]]
302
+ # run through Transformer to get conditional hidden states
286
303
mode_hs = self .xformer (embs , h_ctx , h_tgt )
304
+ # permute back to canonical order
287
305
mode_hs = [mode_hs [i ] for i in perm .argsort ()]
288
306
307
+ # final projections to raw distribution parameters
289
308
pitch_params , time_params , vel_params = [
290
309
proj (h ) for proj ,h in zip (self .projections , mode_hs )]
291
310
292
- # get likelihoods
311
+ # get likelihoods of data for each modality
293
312
pitch_logits = F .log_softmax (pitch_params , - 1 )
294
313
pitch_targets = pitches [:,1 :,None ] #batch, time, 1
295
314
pitch_log_probs = pitch_logits .gather (- 1 , pitch_targets )[...,0 ]
@@ -309,6 +328,8 @@ def forward(self, pitches, times, velocities, validation=False):
309
328
** {'time_' + k :v for k ,v in time_result .items ()},
310
329
** {'velocity_' + k :v for k ,v in vel_result .items ()}
311
330
}
331
+ # this just computes some extra diagnostics which are inconvenient to do in the
332
+ # training script. leave validation=False during training for performance.
312
333
if validation :
313
334
with torch .no_grad ():
314
335
r ['time_acc_30ms' ] = (
0 commit comments