Need another set of eyes on weird transformer setup

Hey all,

I've been working on applying a transformer variant to spectrograms for music source separation. The idea was inspired by MMDenseLSTM, though I'm sure it's been done elsewhere; mainly I'm doing this as a learning experience.

My supervised model using transformer modules works perfectly fine: it trains far faster than the convolutional variant, reaches the same level of quality in the same amount of time, and is faster at inference. But when I try to use the model in a translation setup, applied directly to a mix spectrogram and an instrumental spectrogram, it seems to either overfit or learn how to cheat; I'm not entirely sure what I'm seeing.

To train, I take the target, pad it on the left by one frame, slice the last frame off, and feed that in as the target sequence. The loss is then calculated between the transformer's output and the unpadded, unsliced target, so each frame predicts what its transformed counterpart (a mask applied to the input) will be. Predicting the mask while being fed the current instrumental frame is a little weird, but directly predicting raw frequency bins doesn't offer much help either.
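For reference, here's a toy sketch of that shift (the names are just illustrative), assuming spectrograms shaped (B, C, H, W) with W as the frame axis:

import torch
import torch.nn.functional as F

B, C, H, W = 1, 2, 4, 5
tgt = torch.rand(B, C, H, W)                 # target (instrumental) spectrogram

# shift right along the frame axis: prepend a zero frame, drop the last frame
tgt_in = F.pad(tgt, (1, 0))[:, :, :, :-1]    # decoder input, still (B, C, H, W)

# frame t of tgt_in is frame t-1 of tgt, so the decoder predicts frame t
# from earlier frames plus the encoder memory
assert torch.equal(tgt_in[..., 1:], tgt[..., :-1])
assert torch.equal(tgt_in[..., 0], torch.zeros(B, C, H))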

I'm wondering if anyone sees anything off in the following code that could contribute to cheating, something like the causal mask being applied incorrectly for my tensor shape (B, C, H, W → B, W, H in the transformer: the frequency bins (H) of the spectrogram are treated as the embedding dimensions for each frame, and the frames (W) as the sequence). It could be overfitting, but the speed at which validation loss starts increasing after initially decreasing is a bit odd and makes me think it's cheating. Another place that could be an issue is the validation code itself. Any recommendations would be greatly appreciated; I'm not doing this for work or anything, but I feel a strong desire to get this working and to understand why it isn't working now.
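To make that shape convention concrete, each encoder/decoder layer collapses the channel axis and treats frames as the sequence; roughly this (a sketch, shapes only):

import torch
import torch.nn as nn

B, C, H, W = 8, 2, 1024, 512                    # batch, channels, bins, frames
x = torch.rand(B, C, H, W)

# collapse channels with a linear projection, then drop that axis,
# leaving (B, W, H): frames as the sequence, bins as the embedding dimension
in_project = nn.Linear(C, 1, bias=False)
seq = in_project(x.transpose(1, 3)).squeeze(3)  # (B, W, H)
print(seq.shape)                                # torch.Size([8, 512, 1024])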

The code for the main transformer module is as follows:

class FrameTransformer(nn.Module):
    def __init__(self, channels, n_fft=2048, feedforward_dim=512, num_bands=4, num_encoders=1, num_decoders=1, cropsize=1024, bias=False, autoregressive=False, out_activate=ReLU1()):
        super(FrameTransformer, self).__init__()
        
        self.max_bin = n_fft // 2
        self.output_bin = n_fft // 2 + 1
        self.cropsize = cropsize
        
        self.register_buffer('mask', torch.triu(torch.ones(cropsize, cropsize) * float('-inf'), diagonal=1))
        self.encoder = nn.ModuleList([FrameTransformerEncoder(channels + i, bins=self.max_bin, num_bands=num_bands, cropsize=cropsize, feedforward_dim=feedforward_dim, bias=bias) for i in range(num_encoders)])
        self.decoder = nn.ModuleList([FrameTransformerDecoder(channels + i, channels + num_encoders, bins=self.max_bin, num_bands=num_bands, cropsize=cropsize, feedforward_dim=feedforward_dim, bias=bias) for i in range(num_decoders)])
        self.out = nn.Linear(channels + num_decoders, 2, bias=bias)
        self.activate = out_activate

        self.register_buffer('indices', torch.arange(cropsize))
        self.embedding = nn.Embedding(cropsize, self.max_bin)

    def embed(self, x):
        e = self.embedding(self.indices).t()
        return x + e

    def __call__(self, src, tgt):
        mem = self.encode(src)
        tgt = self.decode(tgt, mem=mem)
        return tgt

    def encode(self, src):
        src = self.embed(src[:, :, :self.max_bin])

        for module in self.encoder:
            t = module(src)
            src = torch.cat((src, t), dim=1)

        return src

    def decode(self, tgt, mem):
        tgt = self.embed(tgt[:, :, :self.max_bin])

        for module in self.decoder:
            t = module(tgt, mem=mem, mask=self.mask)
            tgt = torch.cat((tgt, t), dim=1)

        return F.pad(
            input=self.activate(self.out(tgt.transpose(1,3)).transpose(1,3)),
            pad=(0, 0, 0, self.output_bin - self.max_bin),
            mode='replicate'
        )
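For anyone skimming, here's roughly how I call this (a sketch: the hyperparameters are placeholders rather than my actual training config, and ReLU1 is assumed to clamp the output to [0, 1]):

import torch

model = FrameTransformer(channels=2, n_fft=2048, num_encoders=4, num_decoders=4, cropsize=256)

src = torch.rand(1, 2, 1025, 256)   # mix spectrogram: (B, C, n_fft // 2 + 1, frames)
tgt = torch.rand(1, 2, 1025, 256)   # shifted instrumental spectrogram, same shape

mask = model(src, tgt)              # (B, 2, 1025, 256), clamped to [0, 1] by ReLU1
pred = src * mask                   # estimated instrumental spectrogram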

The encoder and decoder are defined below:

class FrameTransformerEncoder(nn.Module):
    def __init__(self, channels, bins, num_bands=4, cropsize=1024, feedforward_dim=2048, bias=False, dropout=0.1, autoregressive=False):
        super(FrameTransformerEncoder, self).__init__()

        self.bins = bins
        self.cropsize = cropsize
        self.num_bands = num_bands
        self.autoregressive = autoregressive
        self.in_project = nn.Linear(channels, 1, bias=bias)
        self.encoder = nn.TransformerEncoderLayer(bins, num_bands, feedforward_dim, batch_first=True, norm_first=True, dropout=dropout)

    def __call__(self, x, mask=None):
        x = self.in_project(x.transpose(1,3)).squeeze(3)
        return self.encoder(x).transpose(1,2).unsqueeze(1)

class FrameTransformerDecoder(nn.Module):
    def __init__(self, channels, skip_channels, num_bands=4, cropsize=1024, bins=2048, feedforward_dim=2048, downsamples=0, bias=False, dropout=0.1):
        super(FrameTransformerDecoder, self).__init__()

        self.bins = bins
        self.cropsize = cropsize
        self.num_bands = num_bands
        self.in_project = nn.Linear(channels, 1, bias=bias)
        self.mem_project = nn.Linear(skip_channels, 1, bias=bias)
        self.decoder = nn.TransformerDecoderLayer(bins, num_bands, feedforward_dim, batch_first=True, norm_first=True, dropout=dropout)

    def __call__(self, x, mem, mask=None):
        x = self.in_project(x.transpose(1,3)).squeeze(3)
        mem = self.mem_project(mem.transpose(1,3)).squeeze(3)
        return self.decoder(tgt=x, memory=mem, tgt_mask=mask).transpose(1,2).unsqueeze(1)

The forward pass and loss calculation for training:

        src = src.to(device)
        tgt = tgt.to(device)

        with torch.cuda.amp.autocast_mode.autocast(enabled=grad_scaler is not None):
            pred = model(src, F.pad(tgt, (1,0))[:, :, :, :-1])

        loss = crit(src * pred, tgt)

The forward pass and loss calculation for validation (horrifically inefficient currently):

            src = src.to(device)
            tgt = tgt.to(device)

            mem = model.encode(src)        
            h = torch.zeros_like(src)

            for i in tqdm(range(src.shape[3])):
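                # at step i, the causal mask should ensure frame i of the decoder input
                # only attends to earlier frames of (src * h), so h is re-estimated over
                # the whole crop each pass and gradually filled in left to right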
                h = model.decode(F.pad(src * h, (1,0))[:, :, :, :-1], mem=mem, idx=i)
 
            loss = crit(src * h, tgt)

If any further code is desired, let me know; I do have this on GitHub, but it's a more complex setup there. I simplified the architecture to use PyTorch's built-in transformer modules as a sanity check, and I figure it's easier to understand anyway.

Edit: It occurs to me that this variant didn't originally have any positional encoding… I was using relative positional encoding in my last iteration, which is why it wasn't carried over here. Going to add fixed positional encoding and see if that changes anything.
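For reference, by "fixed positional encoding" I mean something along these lines: a standard sinusoidal encoding over the frame axis, written as a drop-in for the embed() step above with d_model equal to max_bin (a sketch, not necessarily the exact code I ended up with):

import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (B, C, H, W) with H = d_model bins and W frames;
        # transpose the table to (H, W) so it broadcasts over batch and channels
        return x + self.pe[:x.shape[3]].t()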

Edit 2: It also occurs to me that I was passing in the mask rather than the expected mask * source in the validation loop…

Update: Validation loss was significantly better with the corrected mask input and fixed positional encoding added; however, it only saw around 10% of my dataset (which is >1TB, so a full epoch takes hours). I'll be running some tests while I work to see whether this converges.

Update 2: Validation loss dropped yet again on the second epoch, on another 10% of my dataset. This seems to point to one of two things having been the issue: overfitting (I upped dropout from 0.1 to 0.4), or the validation code passing in the mask instead of the source times the mask (which is pretty obviously going to hurt). I'm going to let it train a third epoch and see what happens, but previously it would always have started getting worse by this point, so I'm starting to become hopeful. It could also potentially have been positional encoding; however, the original architecture with relative positional encoding had these same training dynamics, so I suspect not.

(I've updated the post with the corrected code, which at least seems not to completely break.)

Don't mean to continuously bump this; sorry if it's annoying (I also didn't realize editing bumped the thread). I think the version with absolute positional encoding and the fixed decoder input might be learning a little, but I'm not totally sure. In the resulting output vocal spectrogram you can clearly see areas where it reacted to where vocals began and ended, and the supervised model appears to focus on similar areas of each frame, but the translator is doing quite poorly. It went through two full epochs and ended up about where it was the previous time. You can definitely hear vocals in the 'extracted' vocal spectrogram, and they are the dominant audio, but it also consistently removes other parts of the audio (such as the kick and snare, and a lot of the high end from guitar). I still feel like I have an obvious mistake sitting in here somewhere, but it seems most likely that it's in the training or validation code.

Will probably try pretraining an encoder-only variant overnight and then adding decoders to it to see if that helps with convergence, as I'm starting to think this is just going to take a lot longer to train. But I'd be happy if someone came in and pointed out some rookie mistake I'm making, haha; ultimately I'm doing this to learn and would prefer someone bluntly call out anything stupid.

Edit: The third epoch just finished, and validation loss dropped this time. It seems as though this will just take a lot longer to train, but I'm still open to anyone pointing out mistakes I've made or suggesting improvements.

Hmm, so I trained this using a somewhat weird setup and it got weirdly accurate at reconstructing songs, so I'm guessing it's cheating somehow. The audio quality suffered, but it was able to predict the audio almost completely in terms of song structure and notes; only the mix was off. The architecture I used for this test was a variant of the Primer with the relative positional encoding from the Music Transformer. At this point I feel quite confident there is a glaring issue in the code that I'm missing.

In case anyone gets any ideas about what's wrong, here's the code I used for this test:

Transformer module:

class FrameTransformer(nn.Module):
    def __init__(self, channels, n_fft=2048, feedforward_dim=512, num_bands=4, num_encoders=1, num_decoders=1, cropsize=1024, bias=False, autoregressive=False, out_activate=ReLU1(), encoder_only=False):
        super(FrameTransformer, self).__init__()
        
        self.max_bin = n_fft // 2
        self.output_bin = n_fft // 2 + 1
        self.cropsize = cropsize
        
        self.encoder_only = encoder_only
        self.register_buffer('mask', torch.triu(torch.ones(cropsize, cropsize) * float('-inf'), diagonal=1))
        self.encoder = nn.ModuleList([FrameTransformerEncoder(channels + i, bins=self.max_bin, num_bands=num_bands, cropsize=cropsize, feedforward_dim=feedforward_dim, bias=bias) for i in range(num_encoders)])
        self.decoder = nn.ModuleList([FrameTransformerDecoder(channels + i, channels + num_encoders, bins=self.max_bin, num_bands=num_bands, cropsize=cropsize, feedforward_dim=feedforward_dim, bias=bias) for i in range(num_decoders)]) if not encoder_only else None
        self.out = nn.Linear(channels + (num_decoders if not encoder_only else num_encoders), 2, bias=bias)
        self.activate = out_activate if out_activate is not None else nn.Identity()

    def __call__(self, src, tgt=None):
        if self.encoder_only:
            out = self.encode(src)
        else:
            mem = self.encode(src)
            out = self.decode(tgt, mem=mem)

        return F.pad(
            input=self.activate(self.out(out.transpose(1,3)).transpose(1,3)),
            pad=(0, 0, 0, self.output_bin - self.max_bin),
            mode='replicate'
        )

    def encode(self, src):
        src = src[:, :, :self.max_bin]

        for module in self.encoder:
            t = module(src, mask=self.mask)
            src = torch.cat((src, t), dim=1)

        return src

    def decode(self, tgt, mem):
        tgt = tgt[:, :, :self.max_bin]

        for module in self.decoder:
            t = module(tgt, mem=mem, mask=self.mask)
            tgt = torch.cat((tgt, t), dim=1)

        # the output projection and padding are applied in __call__, so just
        # return the concatenated decoder features here
        return tgt

Multihead attention (named multiband here due to context):

class MultibandFrameAttention(nn.Module):
    def __init__(self, num_bands, bins, cropsize, kernel_size=3):
        super().__init__()

        self.num_bands = num_bands

        self.q_proj = nn.Linear(bins, bins)
        self.q_conv = CausalConv1d(bins, bins, kernel_size=kernel_size, groups=bins)

        self.k_proj = nn.Linear(bins, bins)
        self.k_conv = CausalConv1d(bins, bins, kernel_size=kernel_size, groups=bins)

        self.v_proj = nn.Linear(bins, bins)
        self.v_conv = CausalConv1d(bins, bins, kernel_size=kernel_size, groups=bins)

        self.o_proj = nn.Linear(bins, bins)

        self.er = nn.Parameter(torch.empty(bins // num_bands, cropsize))
        nn.init.normal_(self.er)

    def forward(self, x, mem=None, mask=None):
        b,w,c = x.shape

        q = self.q_conv(self.q_proj(x).transpose(1,2)).transpose(1,2).reshape(b, w, self.num_bands, -1).permute(0,2,1,3)
        k = self.k_conv(self.k_proj(x if mem is None else mem).transpose(1,2)).transpose(1,2).reshape(b, w, self.num_bands, -1).permute(0,2,3,1)
        v = self.v_conv(self.v_proj(x if mem is None else mem).transpose(1,2)).transpose(1,2).reshape(b, w, self.num_bands, -1).permute(0,2,1,3)
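        # Music Transformer style "skew" trick: aligns the (w, cropsize) relative-position
        # logits with the (w, w) attention scores below (assumes w == cropsize)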
        p = F.pad(torch.matmul(q,self.er), (1,0)).transpose(2,3)[:,:,1:,:]
        qk = (torch.matmul(q,k)+p) / math.sqrt(c)

        if mask is not None:
            qk = qk + mask

        a = F.softmax(qk, dim=-1)
        a = torch.matmul(a,v).transpose(1,2).reshape(b,w,-1)
        o = self.o_proj(a)
        return o

The causal convolution used in the attention module:

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, bias=True):
        super(CausalConv1d, self).__init__()

        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, kernel_size))
        self.bias = nn.Parameter(torch.empty(out_channels)) if bias else None
        self.kernel_size = kernel_size
        self.padding = kernel_size - 1
        self.groups = groups
        self.stride = stride
        
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
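        # pad only on the left so output frame t never depends on frames > t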
        return F.conv1d(F.pad(x, (self.kernel_size - 1, 0)), weight=self.weight, bias=self.bias, stride=self.stride, groups=self.groups)

And the encoder (I'm pretraining with an encoder-only variant):

class FrameTransformerEncoder(nn.Module):
    def __init__(self, channels, bins, num_bands=4, cropsize=1024, feedforward_dim=2048, bias=False, dropout=0.1, autoregressive=False):
        super(FrameTransformerEncoder, self).__init__()

        self.bins = bins
        self.cropsize = cropsize
        self.num_bands = num_bands
        self.autoregressive = autoregressive

        self.in_project = nn.Linear(channels, 1, bias=bias)

        self.norm1 = nn.LayerNorm(bins)
        self.attn = MultibandFrameAttention(num_bands, bins, cropsize, kernel_size=3)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(bins)
        self.relu = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(bins, feedforward_dim, bias=bias)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(feedforward_dim, bins, bias=bias)
        self.dropout3 = nn.Dropout(dropout)

    def __call__(self, x, mask=None):
        x = self.in_project(x.transpose(1,3)).squeeze(3)

        h = self.norm1(x)
        h = self.attn(h, mask=mask)
        x = x + self.dropout1(h)

        h = self.norm2(x)
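        # squared-ReLU feedforward, as in the Primer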
        h = self.linear2(self.dropout2(torch.square(self.relu(self.linear1(h)))))
        x = x + self.dropout3(h)

        return x.transpose(1,2).unsqueeze(1)

The loss calculation is slightly unorthodox here: the model output goes through a sigmoid, and (sigmoid(x) * 2) - 1 is used as an additive mask on the current frame, so the model only has to learn a residual that transforms the current frame into the next:

        src = src.to(device)
        tgt = tgt.to(device)

        with torch.cuda.amp.autocast_mode.autocast(enabled=grad_scaler is not None):
            pred = (model(src) * 2) - 1

        loss = crit(torch.relu(src + pred), tgt)
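Spelled out on a toy frame (assuming magnitudes normalized to [0, 1] and the sigmoid output described above):

import torch

m = torch.rand(1024)                  # sigmoid output for one frame, in (0, 1)
frame = torch.rand(1024)              # current frame magnitudes in [0, 1]
additive_mask = m * 2 - 1             # rescaled to (-1, 1)
next_frame = torch.relu(frame + additive_mask)  # clamp negatives so it stays a valid magnitude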

Hopefully someone sees some issue with this. I'm trying a variant without the causal convolutions, as I'm running out of ideas, lol. The predicted output is far from perfect, but I find it very difficult to believe it could predict musical structure with any accuracy after two epochs, even with the 1TB of audio data I used for this pretraining…

Edit: I guess that, due to the residual connection, it's probably just slowly learning to shift the audio forward by one frame. The audio isn't perfectly aligned, so it appears to be learning that slowly, which makes sense given the residual (it also makes sense that it would learn somewhat slowly, since most adjacent frames are fairly close to each other in terms of their frequency bins).

I pretrained this model to predict the next frame directly (so its output was between 0 and 1) and then used the raw output as the prediction rather than as a mask. I only pretrained in this fashion for 10 epochs, but after doing so and starting supervised training with an LR warmup scheduler, the loss started off noticeably lower and trended down far more rapidly than for any previous model. I used a chained LR scheduler with warmup and polynomial decay for the pretraining session, but I think the decay horizon was probably too long; the learning rate was basically still at 1e-4 when I stopped at epoch 10 (more accurately, something like 9.5e-5). Validation loss was still trending down during pretraining, but I didn't want to waste time if it wasn't working at all. I still have to verify validation loss for the supervised version, but the training loss is hope-inspiring at least. The pretrained model can definitely be taken much further, and a larger model will likely do significantly better; its output was basically white noise with rhythm and some ghostly harmonic backing.
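For reference, the pretraining scheduler was along these lines (a sketch: the warmup length, decay horizon, power, and the stand-in model are placeholders, not my exact values):

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ChainedScheduler, LinearLR, PolynomialLR

model = nn.Linear(10, 10)  # stand-in for the actual FrameTransformer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# warm up over the first 2k steps while decaying polynomially over the full run;
# both schedulers are stepped together every iteration
warmup = LinearLR(optimizer, start_factor=1e-3, total_iters=2000)
decay = PolynomialLR(optimizer, total_iters=1_000_000, power=1.0)
scheduler = ChainedScheduler([warmup, decay])

for step in range(1_000_000):
    # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad() ...
    scheduler.step()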

So anyway, it looks like the original errors in this post were simple mistakes. Probably why I shouldn't be coding at 6am when I have to be online for work by 10am at the latest… Leaving this up in case any of this code helps some random person down the road (which seems far-fetched, but hey, who knows).

To update, since I left that hanging: validation loss was lower for the model initialized from the pretrained autoregressive checkpoint; however, I'm not really convinced it was anything more than random weight initialization, as the difference wasn't significant. Validation loss (MAE) was 0.001943 for the model trained from scratch versus 0.001904 for the model started from the pretrained checkpoint.

However, on that note, I also realize that pretraining in an autoregressive fashion here is kind of stupid. I decided to BERT-ify things as much as I could and ended up writing a new dataset that randomly whites out frames of the spectrogram (setting every frequency bin in a frame to 1.0, which should be fairly rare in real music and can act like BERT's mask token), masks those frames in the attention score matrix, and then has the model try to predict what the whited-out frames were. Unfortunately, I think it's overfitting, because it's getting a little too accurate after only the first epoch. I might need to spend this weekend collecting more data…
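The frame whiting-out itself looks roughly like this (a sketch; the masking probability and the helper name are just illustrative):

import torch

def white_out_frames(spec, p=0.15, mask_value=1.0):
    # spec: (C, H, W) magnitude spectrogram; pick ~p of the W frames, set every
    # bin in those frames to mask_value, and return the indices so the loss
    # (and the attention-score masking) can target only the whited-out frames
    c, h, w = spec.shape
    masked = torch.rand(w) < p
    corrupted = spec.clone()
    corrupted[:, :, masked] = mask_value
    return corrupted, masked

# usage: train the encoder to reconstruct spec[:, :, masked] from corrupted
spec = torch.rand(2, 1024, 256)
corrupted, masked = white_out_frames(spec)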

I'd be curious whether anyone has ideas on pretraining for this use case, as it veers a bit from the typical path for transformers. I'm not sure what else I could do, but I feel as though I need more than just masked frame prediction for it to be an effective pretraining setup the way it is with BERT. I'm currently reading [1912.10211] PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition (arxiv.org), but I'd be happy to hear any further reading recommendations or ideas.