I’ve been working on applying a transformer variant to spectrograms for music purposes. The idea was inspired by MMDENSELSTM, though I’m sure it’s been done elsewhere; I’m mainly doing this as a learning experience.
My supervised model using transformer modules works perfectly fine: it trains far faster than the convolutional variant, reaches the same level of quality in the same amount of time, and is faster at inference. However, when I use the model in a translation setup applied directly to a mix spectrogram and an instrumental spectrogram, it seems to either overfit or learn how to cheat. I’m not entirely sure what I’m seeing.
To train, I take the target, pad it on the left by 1 frame, slice the last frame off, and feed that in as the target sequence. The loss is then calculated between the unpadded, unsliced target and the output of the transformer, so each frame predicts what its transformed counterpart (a mask applied to the input) will be. Predicting the mask while being fed the current instrumental is a little weird, but directly predicting raw frequency bins doesn’t help much either.
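For reference, here is a minimal sketch of that shift on a toy tensor. The sizes are made up for illustration (2 frequency bins, 4 frames), and `decoder_in` is just a name chosen for this example.

```python
import torch
import torch.nn.functional as F

# Toy "spectrogram": batch=1, channels=1, 2 frequency bins (H), 4 frames (W).
tgt = torch.arange(1.0, 9.0).reshape(1, 1, 2, 4)

# Pad one zero frame on the left of the last (frame) axis, then drop the last frame.
# Frame t of decoder_in now holds frame t-1 of tgt.
decoder_in = F.pad(tgt, (1, 0))[:, :, :, :-1]

print(tgt[0, 0])         # rows [1, 2, 3, 4] and [5, 6, 7, 8]
print(decoder_in[0, 0])  # rows [0, 1, 2, 3] and [0, 5, 6, 7]
```

The shape is unchanged, so the loss can still be computed frame by frame against the unshifted target.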
I’m wondering if anyone sees anything off with the following code that could contribute to cheating, such as the mask being applied incorrectly for my tensor shape (B,C,H,W → B,W,H in the transformer; the frequency bins (H) of the spectrogram are treated as the embedding dimensions for each frame, and the frames (W) as the sequence). It could be overfitting, but the speed at which validation loss starts increasing after initially decreasing is a bit odd and makes me think it’s cheating. One other place that could be an issue is the validation code itself. Any recommendations would be greatly appreciated. I’m not doing this for work or anything, but I feel a strong desire to get this working and to understand why it isn’t working now.
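One way to test for this kind of leak is to change only the later frames of the decoder input and check that the outputs for the earlier frames stay the same. The sketch below does this on a single `nn.TransformerDecoderLayer` using the same (B,W,H) layout. The sizes and the `max_leak` helper are assumptions for illustration, not part of the model below, and dropout is set to 0 so the two passes are comparable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes: 16 frames (sequence), 32 bins (embedding dim), 4 heads.
frames, bins, heads = 16, 32, 4

layer = nn.TransformerDecoderLayer(bins, heads, dim_feedforward=64,
                                   batch_first=True, norm_first=True, dropout=0.0)
layer.eval()

# Same construction idea as the model's mask: -inf above the diagonal.
causal_mask = torch.triu(torch.full((frames, frames), float('-inf')), diagonal=1)

def max_leak(tgt_mask, k=8):
    """Largest change in outputs for frames < k when only frames >= k are altered."""
    tgt = torch.randn(1, frames, bins)
    mem = torch.randn(1, frames, bins)
    perturbed = tgt.clone()
    perturbed[:, k:] = torch.randn(1, frames - k, bins)  # replace the "future" frames
    with torch.no_grad():
        a = layer(tgt, mem, tgt_mask=tgt_mask)
        b = layer(perturbed, mem, tgt_mask=tgt_mask)
    return (a[:, :k] - b[:, :k]).abs().max().item()

leak_masked = max_leak(causal_mask)  # should be (numerically) zero
leak_unmasked = max_leak(None)       # should be clearly non-zero
print(leak_masked, leak_unmasked)
```

The same test can be run against the full model by perturbing the last frames of the shifted target and comparing the first frames of the prediction. Cross-attention to the encoder memory is not masked here, which is expected as long as the memory is built from the mix only.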
The code for the main transformer module is as follows:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # ReLU1 is a custom activation module that is not shown in this post.
    class FrameTransformer(nn.Module):
        def __init__(self, channels, n_fft=2048, feedforward_dim=512, num_bands=4,
                     num_encoders=1, num_decoders=1, cropsize=1024, bias=False,
                     autoregressive=False, out_activate=ReLU1()):
            super(FrameTransformer, self).__init__()

            self.max_bin = n_fft // 2
            self.output_bin = n_fft // 2 + 1
            self.cropsize = cropsize

            # Causal mask: -inf above the diagonal, so frame t cannot attend to frames > t.
            self.register_buffer('mask', torch.triu(torch.ones(cropsize, cropsize) * float('-inf'), diagonal=1))

            # Each layer's output is concatenated as an extra channel, hence channels + i.
            self.encoder = nn.ModuleList([
                FrameTransformerEncoder(channels + i, bins=self.max_bin, num_bands=num_bands,
                                        cropsize=cropsize, feedforward_dim=feedforward_dim, bias=bias)
                for i in range(num_encoders)])
            self.decoder = nn.ModuleList([
                FrameTransformerDecoder(channels + i, channels + num_encoders, bins=self.max_bin,
                                        num_bands=num_bands, cropsize=cropsize,
                                        feedforward_dim=feedforward_dim, bias=bias)
                for i in range(num_decoders)])

            self.out = nn.Linear(channels + num_decoders, 2, bias=bias)
            self.activate = out_activate

            # Learned positional embedding: one vector of max_bin values per frame index.
            self.register_buffer('indices', torch.arange(cropsize))
            self.embedding = nn.Embedding(cropsize, self.max_bin)

        def embed(self, x):
            # (cropsize, max_bin) -> (max_bin, cropsize), broadcast over (B, C, H, W).
            e = self.embedding(self.indices).t()
            return x + e

        def forward(self, src, tgt):
            mem = self.encode(src)
            tgt = self.decode(tgt, mem=mem)
            return tgt

        def encode(self, src):
            src = self.embed(src[:, :, :self.max_bin])

            for module in self.encoder:
                t = module(src)
                src = torch.cat((src, t), dim=1)

            return src

        def decode(self, tgt, mem):
            tgt = self.embed(tgt[:, :, :self.max_bin])

            for module in self.decoder:
                t = module(tgt, mem=mem, mask=self.mask)
                tgt = torch.cat((tgt, t), dim=1)

            # Project channels down to 2, then pad the bin axis back up to output_bin.
            return F.pad(
                input=self.activate(self.out(tgt.transpose(1,3)).transpose(1,3)),
                pad=(0, 0, 0, self.output_bin - self.max_bin),
                mode='replicate'
            )
The encoder and decoder are defined below:
    import torch.nn as nn

    class FrameTransformerEncoder(nn.Module):
        def __init__(self, channels, bins, num_bands=4, cropsize=1024, feedforward_dim=2048,
                     bias=False, dropout=0.1, autoregressive=False):
            super(FrameTransformerEncoder, self).__init__()

            self.bins = bins
            self.cropsize = cropsize
            self.num_bands = num_bands
            self.autoregressive = autoregressive

            self.in_project = nn.Linear(channels, 1, bias=bias)
            # d_model = bins, nhead = num_bands
            self.encoder = nn.TransformerEncoderLayer(bins, num_bands, feedforward_dim,
                                                      batch_first=True, norm_first=True, dropout=dropout)

        def forward(self, x, mask=None):
            # (B, C, H, W) -> (B, W, H, C) -> (B, W, H): collapse channels to one value per bin.
            x = self.in_project(x.transpose(1,3)).squeeze(3)
            # (B, W, H) -> (B, H, W) -> (B, 1, H, W)
            return self.encoder(x).transpose(1,2).unsqueeze(1)

    class FrameTransformerDecoder(nn.Module):
        def __init__(self, channels, skip_channels, num_bands=4, cropsize=1024, bins=2048,
                     feedforward_dim=2048, downsamples=0, bias=False, dropout=0.1):
            super(FrameTransformerDecoder, self).__init__()

            self.bins = bins
            self.cropsize = cropsize
            self.num_bands = num_bands

            self.in_project = nn.Linear(channels, 1, bias=bias)
            self.mem_project = nn.Linear(skip_channels, 1, bias=bias)
            self.decoder = nn.TransformerDecoderLayer(bins, num_bands, feedforward_dim,
                                                      batch_first=True, norm_first=True, dropout=dropout)

        def forward(self, x, mem, mask=None):
            x = self.in_project(x.transpose(1,3)).squeeze(3)
            mem = self.mem_project(mem.transpose(1,3)).squeeze(3)
            # The causal mask is applied to self-attention over the target frames only.
            return self.decoder(tgt=x, memory=mem, tgt_mask=mask).transpose(1,2).unsqueeze(1)
The forward pass and loss calculation for training:
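    src = src.to(device)
    tgt = tgt.to(device)

    with torch.cuda.amp.autocast_mode.autocast(enabled=grad_scaler is not None):
        # Teacher forcing: shift the target right by one frame along the last axis.
        pred = model(src, F.pad(tgt, (1,0))[:, :, :, :-1])

        # pred is a mask; the masked mix is compared against the instrumental.
        loss = crit(src * pred, tgt)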
The forward pass and loss calculation for validation (horrifically inefficient currently):
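    src = src.to(device)
    tgt = tgt.to(device)

    mem = model.encode(src)
    h = torch.zeros_like(src)

    # One pass per frame; src.shape[3] is the number of frames (W).
    # decode() recomputes every frame on each pass, and the causal mask keeps
    # the frames up to i unchanged once they have been generated.
    for i in tqdm(range(src.shape[3])):
        h = model.decode(F.pad(src * h, (1,0))[:, :, :, :-1], mem=mem)

    loss = crit(src * h, tgt)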
If any further code is desired, let me know. I do have this on GitHub, but it’s a more complex setup. I simplified the architecture to use PyTorch’s built-in transformer modules as a sanity check, and I figure it’s easier to understand anyway.
Edit: it occurs to me this variant doesn’t have positional encoding… I was using relative positional encoding in my last iteration, which is why it wasn’t directly used here. Going to add fixed positional encoding and see if that changes anything.
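A minimal sketch of what a fixed positional encoding could look like for this layout, assuming the standard sine/cosine formulation and an even number of bins. `sinusoidal_encoding` is a hypothetical helper, not code from the repository; it returns a (bins, frames) tensor so that it broadcasts over (B,C,H,W) the same way `embed` does.

```python
import math
import torch

def sinusoidal_encoding(num_frames, num_bins):
    """Fixed sine/cosine positional encoding, returned as (num_bins, num_frames).

    Assumes num_bins is even. Frames are the sequence positions and bins are the
    embedding dimensions, matching the B,W,H layout used inside the transformer.
    """
    position = torch.arange(num_frames, dtype=torch.float32).unsqueeze(1)  # (W, 1)
    div_term = torch.exp(
        torch.arange(0, num_bins, 2, dtype=torch.float32) * (-math.log(10000.0) / num_bins)
    )                                                                      # (H / 2,)
    pe = torch.zeros(num_frames, num_bins)
    pe[:, 0::2] = torch.sin(position * div_term)  # even bins
    pe[:, 1::2] = torch.cos(position * div_term)  # odd bins
    return pe.t()                                 # (H, W)

# Toy usage: batch of 2, 2 channels, 8 bins, 16 frames.
x = torch.zeros(2, 2, 8, 16)
pe = sinusoidal_encoding(num_frames=16, num_bins=8)
y = x + pe  # broadcasts over batch and channel dimensions
print(y.shape)
```

In a module this tensor would typically be stored with `register_buffer` so that it moves to the right device with the model and is not trained.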
Edit 2: It also occurs to me that I’m passing in the mask rather than the expected mask * source…
Update: Validation loss was significantly better with the correct mask added and fixed positional encoding. However, it only saw around 10% of my dataset (which is >1 TB, so it takes hours). I will be running some tests while I work to see if this converges.
Update 2: Validation loss dropped yet again for the second epoch on another 10% of my dataset. This seems to point to one of two things having been the issue: overfitting (I upped dropout from 0.1 to 0.4), or the validation code passing in the mask instead of the source times the mask (which is pretty obviously going to hurt). I’m going to let it train a third epoch and see what happens; it would always start getting worse by this point, so I’m starting to become hopeful. It could also have been positional encoding, but the original architecture with relative positional encoding had these same training dynamics, so I suspect not.
(Updated post with corrected code that at least seems to not completely break)