@@ -81,7 +81,19 @@ def forward(self, x):
8181# return self.net(x)
8282
8383class ModalityTransformer (nn .Module ):
84- """Model joint distribution of modalities autoregressively with random permutations"""
84+ """
85+ Model joint distribution of note modalities (e.g. pitch, time, velocity).
86+
87+ This is an autoregressive Transformer model for the *internal* structure of notes.
88+ It is *not* autoregressive in time, but in modality.
89+ At training time, it executes in parallel over all timesteps and modalities, with
90+ time dependencies provided via the RNN backbone.
91+
92+ At sampling time it is called serially, one modality at a time,
93+ repeatedly at each time step.
94+
95+ Inspired by XLNet: http://arxiv.org/abs/1906.08237
96+ """
8597 def __init__ (self , input_size , hidden_size , heads = 4 , layers = 1 ):
8698 super ().__init__ ()
8799 self .net = nn .TransformerDecoder (
@@ -95,13 +107,11 @@ def forward(self, ctx, h_ctx, h_tgt):
95107 ctx: list of Tensor[batch x time x input_size], length note_dim-1
96108 these are the embedded ground truth values
97109 h_ctx: Tensor[batch x time x input_size]
98- (need something to attend to when ctx is empty)
110+ projection of RNN state (need something to attend to when ctx is empty)
99111 h_tgt: list of Tensor[batch x time x input_size], length note_dim
100- these are projections of the RNN state
112+ these are projections of the RNN state for each target,
113+ which the Transformer will map to distribution parameters.
101114 """
102- # h_tgt = list(h_tgt)
103- # ctx = list(ctx)
104-
105115 # explicitly broadcast
106116 h_ctx , * ctx = torch .broadcast_tensors (h_ctx , * ctx )
107117 h_ctx , * h_tgt = torch .broadcast_tensors (h_ctx , * h_tgt )
@@ -122,6 +132,7 @@ def forward(self, ctx, h_ctx, h_tgt):
122132
123133 # generate a mask
124134 # this is both the target and memory mask
135+ # masking is such that each target can only depend on "previous" context
125136 n = len (h_tgt )
126137 mask = ~ tgt .new_ones ((n ,n ), dtype = bool ).tril ()
127138
@@ -254,7 +265,7 @@ def embeddings(self):
254265
255266 def forward (self , pitches , times , velocities , validation = False ):
256267 """
257- teacher-forced probabilistic loss and diagnostics for training
268+ teacher-forced probabilistic loss and diagnostics for training.
258269
259270 Args:
260271 pitches: LongTensor[batch, time]
@@ -263,33 +274,41 @@ def forward(self, pitches, times, velocities, validation=False):
263274 """
264275 batch_size , batch_len = pitches .shape
265276
277+ # embed data to input vectors
266278 pitch_emb = self .pitch_emb (pitches ) # batch, time, emb_size
267279 time_emb = self .time_emb (times ) # batch, time, emb_size
268280 vel_emb = self .vel_emb (velocities ) # batch, time, emb_size
269281
270282 embs = (pitch_emb , time_emb , vel_emb )
271283
284+ # feed to RNN backbone
272285 x = torch .cat (embs , - 1 )[:,:- 1 ] # skip last time position
273286 ## broadcast initial state to batch size
274287 initial_state = tuple (
275288 t .expand (self .rnn .num_layers , x .shape [0 ], - 1 ).contiguous () # 1 x batch x hidden
276289 for t in self .initial_state )
277290 h , _ = self .rnn (x , initial_state ) #batch, time, hidden_size
278291
279- # fit all note factorizations at once.
292+ # fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
280293 # TODO: perm each batch item independently?
294+ # get a random ordering for note modalities:
281295 perm = torch .randperm (self .note_dim )
296+ # chunk RNN state into Transformer inputs
282297 hs = list (self .h_proj (h ).chunk (self .note_dim + 1 , - 1 ))
283298 h_ctx = hs [0 ]
284299 h_tgt = [hs [i + 1 ] for i in perm ]
300+ # embed ground truth values for teacher-forcing
285301 embs = [embs [i ][:,1 :] for i in perm [:- 1 ]]
302+ # run through the Transformer to get conditional hidden states
286303 mode_hs = self .xformer (embs , h_ctx , h_tgt )
304+ # permute back to canonical order
287305 mode_hs = [mode_hs [i ] for i in perm .argsort ()]
288306
307+ # final projections to raw distribution parameters
289308 pitch_params , time_params , vel_params = [
290309 proj (h ) for proj ,h in zip (self .projections , mode_hs )]
291310
292- # get likelihoods
311+ # get likelihoods of data for each modality
293312 pitch_logits = F .log_softmax (pitch_params , - 1 )
294313 pitch_targets = pitches [:,1 :,None ] #batch, time, 1
295314 pitch_log_probs = pitch_logits .gather (- 1 , pitch_targets )[...,0 ]
@@ -309,6 +328,8 @@ def forward(self, pitches, times, velocities, validation=False):
309328 ** {'time_' + k :v for k ,v in time_result .items ()},
310329 ** {'velocity_' + k :v for k ,v in vel_result .items ()}
311330 }
331+ # this computes some extra diagnostics which are inconvenient to compute in the
332+ # training script. it should be turned off during training for performance.
312333 if validation :
313334 with torch .no_grad ():
314335 r ['time_acc_30ms' ] = (
0 commit comments