Hey guys,
(This ended up being a wall of text describing my setup; the actual question is at the bottom.)
I’m trying, as best I can, to apply the BERT pretraining setup to a model I’ve been working on. The model itself is a transformer where each spectrogram frame is treated as a token; it uses the Evolved Transformer encoder architecture with the modifications from Primer-EZ, plus relative positional encoding as seen in the Music Transformer. The input to each transformer encoder is a linear bottleneck that extracts a single-channel representation of both input channels for each frame; the output is a transformed frame that gets concatenated with the encoder’s input, forming a DenseNet out of the transformer encoders, and the final output is bottlenecked back down to 2 channels at the end.
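In case it helps make the data flow concrete, here’s a minimal sketch of what I mean by the dense stack. The names are mine and the plain nn.TransformerEncoderLayer is just a stand-in for my Evolved Transformer/Primer-EZ/relative-attention hybrid, so treat this as an illustration rather than the actual implementation:

```python
import torch
import torch.nn as nn

class DenseTransformerStack(nn.Module):
    """Frames as tokens; each encoder sees a 1-channel bottleneck of
    everything produced so far and appends its output (DenseNet-style)."""
    def __init__(self, n_bins, n_encoders=8, in_channels=2, nhead=8):
        super().__init__()
        self.bottlenecks = nn.ModuleList()
        self.encoders = nn.ModuleList()
        for i in range(n_encoders):
            # linear bottleneck: (in_channels + i) channels -> 1 channel per frame
            self.bottlenecks.append(nn.Linear((in_channels + i) * n_bins, n_bins))
            # stand-in for the Evolved Transformer / Primer-EZ encoder with
            # relative positional attention (n_bins must be divisible by nhead)
            self.encoders.append(nn.TransformerEncoderLayer(
                d_model=n_bins, nhead=nhead, batch_first=True))
        # final bottleneck back down to the 2 spectrogram channels
        self.out = nn.Linear((in_channels + n_encoders) * n_bins,
                             in_channels * n_bins)

    def forward(self, x):                # x: (batch, 2, n_bins, n_frames)
        b, c, f, t = x.shape
        feats = x.permute(0, 3, 1, 2).reshape(b, t, c * f)   # one token per frame
        for bottleneck, encoder in zip(self.bottlenecks, self.encoders):
            h = bottleneck(feats)        # single-channel frame representation
            h = encoder(h)               # transformed frames
            feats = torch.cat([feats, h], dim=-1)            # dense concatenation
        out = self.out(feats)            # back to 2 channels per frame
        return out.reshape(b, t, c, f).permute(0, 2, 3, 1)   # (batch, 2, n_bins, n_frames)
```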
I have a custom data loader that takes in a collection of spectrograms (each double the size actually trained on, so it can cut out different slices of a spectrogram for more data). From there, the input spectrogram is split into groups of N frames. For the first task, the groups are treated as whole ‘words’ and selected at a rate of 15%. As with BERT, 80% of the selected tokens are whited/masked out by setting every frequency bin to 1.0, 10% of the time I take the max between random noise and the original values, and the remaining 10% of the time the original frames are left unmodified. The neural network is then tasked with producing a multiplicative mask by passing its output through a sigmoid, which lets it either leave frames as they are or chisel away at them to recreate the expected audio.
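Roughly, the masking in the data loader looks like this (a simplified sketch; mask_frame_groups and the exact values are placeholders of mine, and I’m assuming magnitudes normalized to [0, 1]):

```python
import torch

def mask_frame_groups(spec, n_frames=16, select_p=0.15):
    """Simplified sketch of the 80/10/10 masking on groups of N frames.
    spec: (channels, n_bins, total_frames); magnitudes assumed in [0, 1]."""
    masked = spec.clone()
    n_groups = spec.shape[-1] // n_frames
    selected = torch.rand(n_groups) < select_p       # pick ~15% of the groups
    for g in torch.nonzero(selected).flatten():
        start, end = g * n_frames, (g + 1) * n_frames
        roll = torch.rand(1).item()
        if roll < 0.8:                               # 80%: white out the group
            masked[..., start:end] = 1.0
        elif roll < 0.9:                             # 10%: max(noise, original)
            noise = torch.rand_like(masked[..., start:end])
            masked[..., start:end] = torch.maximum(noise, masked[..., start:end])
        # remaining 10%: leave the original frames untouched
    return masked, selected                          # selected groups drive the loss
```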
For the second task, a separator ‘token’ (alternating 0s and 1s across the frequency bins, which seems highly unlikely to occur naturally, though maybe I’m missing something obvious) is inserted between the first and second halves of the spectrogram. 50% of the time the second half is replaced with a random slice from a different spectrogram, and the neural network learns to predict whether the second half is a continuation of the first. To keep this from being trivial, I blank out half a token’s worth of frames (16/2 = 8 frames in my case) at the end of the first half and the beginning of the second half, so the transition region isn’t available to the network. The separator token is included in both the input tensor and the target tensor.
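And a sketch of how the pairs for the continuation task get built (again, the names and shapes are simplified placeholders, not my actual loader code):

```python
import torch

def make_continuation_pair(spec, other_spec, n_frames=16):
    """Simplified sketch of the continuation task: split a slice in half,
    insert the separator 'token', and 50% of the time swap the second half
    for a random slice of a different spectrogram."""
    c, n_bins, t = spec.shape
    half = t // 2
    first, second = spec[..., :half].clone(), spec[..., half:].clone()

    is_continuation = torch.rand(1).item() >= 0.5
    if not is_continuation:
        seg = t - half
        start = torch.randint(0, other_spec.shape[-1] - seg + 1, (1,)).item()
        second = other_spec[..., start:start + seg].clone()

    # separator token: alternating 0s and 1s along the frequency axis
    sep = (torch.arange(n_bins) % 2).float().view(1, n_bins, 1)
    sep = sep.expand(c, n_bins, n_frames)

    # hide the transition region: half a token's worth of frames on each side
    gap = n_frames // 2
    first[..., -gap:] = 0.0
    second[..., :gap] = 0.0

    x = torch.cat([first, sep, second], dim=-1)     # sep also goes in the target
    return x, torch.tensor(float(is_continuation))
```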
Now, needless to say, there are some serious differences here compared to NLP. For one, the embeddings in this architecture are the spectrogram frames themselves; this works quite well in a strictly supervised setting when converting mixes with vocals into instrumental mixes, but it’s definitely a major difference. Another is that spectrograms are far more fluid than language: adjacent frames typically transition smoothly into each other (hence masking out larger chunks rather than individual frames, otherwise the model could probably just learn an interpolation function and still be reasonably effective, I’d imagine).
So, this brings me to my question: does anyone have any critiques or suggestions for this idea? I’d be more than happy to share pretrained checkpoints afterward in an open-source fashion and maybe run a little community microproject or something (I do have this on GitHub if anyone is interested, though it’s highly experimental and changes rapidly; it was forked from an MMDENSELSTM implementation and evolved over time). I currently have a model pretraining on thousands of albums. It’s a little hard to gauge how many songs at this point, but it’s well over 1 TB: the data is spread across three SSDs, one 2 TB drive dedicated entirely to this, a 1 TB SSD that’s about 95% dedicated to it, and a 1 TB external USB-C SSD holding roughly 400 GB, since putting more on the external drive adds latency.
My main goal is to use this for track separation, so this isn’t a commercial or academic endeavor; I just love instrumental music lol. I’d also be open to being told I’m making some unsound judgments here and having someone correct me on any of this. I’m a software engineer at a fintech company trying to learn as much as I can about machine learning, and immersing myself in a somewhat challenging problem seems like the best way to do that.
Edit: I did some more testing. In case any of this gives anyone ideas: I changed the architecture to a U-Net and am getting significantly higher quality with it. I use Nx1 kernel convolutions with a stride of 2x1 so that only features from the same frame of the spectrogram are convolved together; this embeds each frame into a lower-dimensional space and adds locality along the pitch dimension, which I imagine is important given the nature of sound, i.e. octaves. The network includes frame encoders that use these Nx1 convs, each followed by a sequence of transformer encoders - a smaller number than usual, 2 in my case (since they sit at every downsampling stage of the U-Net, which here means 5 stages, so 10 encoders operating at different pitch scales). After that, the next frame encoder downsamples along the frequency dimension and embeds more pitch locality while retaining full resolution on the temporal dimension.

What’s interesting is that increasing the channel count at each downsampling doesn’t have a huge effect, which seems to imply it’s the downsampling itself that helps; I haven’t tested this thoroughly yet, short of verifying that lowering the channel count did not in fact hurt validation loss, which is not what I would have expected. The slightly weird part is that before each frame decoder, I use the transformer decoder architecture instead (or at least my hybrid variant of the Evolved Transformer/Primer/Music Transformer). For the memory, I use the U-Net’s skip connection, which includes the output of the transformer encoders at that level, allowing the decoder to effectively query for global information from the original representation at that level. A rough sketch of one stage is below.
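As above, this is only a sketch with plain PyTorch layers standing in for my hybrid transformer blocks, and it assumes an even number of frequency bins and that out_ch * (n_bins // 2) is divisible by the head count:

```python
import torch
import torch.nn as nn

class FrameEncoder(nn.Module):
    """One downsampling stage: an Nx1 conv (frequency only, stride 2x1)
    followed by a couple of transformer encoders over the frame sequence."""
    def __init__(self, in_ch, out_ch, n_bins, kernel=3, n_transformers=2, nhead=8):
        super().__init__()
        # kernel (N, 1) / stride (2, 1): convolve within a frame, halve the
        # frequency resolution, keep the temporal resolution
        self.conv = nn.Conv2d(in_ch, out_ch, (kernel, 1), stride=(2, 1),
                              padding=(kernel // 2, 0))
        d_model = out_ch * (n_bins // 2)   # must be divisible by nhead
        # stand-ins for the Evolved Transformer/Primer/Music Transformer hybrid
        self.encoders = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
            for _ in range(n_transformers)])

    def forward(self, x):                  # x: (batch, in_ch, n_bins, n_frames)
        h = torch.relu(self.conv(x))       # (batch, out_ch, n_bins // 2, n_frames)
        b, c, f, t = h.shape
        seq = h.permute(0, 3, 1, 2).reshape(b, t, c * f)
        for enc in self.encoders:
            seq = enc(seq)
        skip = seq                         # kept as the decoder's memory
        return seq.reshape(b, t, c, f).permute(0, 2, 3, 1), skip

class FrameDecoderAttention(nn.Module):
    """Before each frame decoder: a transformer decoder whose memory is the
    U-Net skip connection (the encoder output at that level)."""
    def __init__(self, d_model, nhead=8):
        super().__init__()
        self.decoder = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)

    def forward(self, x_seq, skip_seq):    # both: (batch, n_frames, d_model)
        return self.decoder(tgt=x_seq, memory=skip_seq)
```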
I still need to actually test what makes the U-Net version so much higher quality; all of this is mainly speculation regarding octaves and pitch locality. The pure transformer variant works, and works fairly well, but the U-Net variant is significantly higher quality. It’s a little weird having a U-Net that downsamples along only one dimension, yet a single U-Net with just the frequency embedding outperforms a DenseNet setup of three U-Nets using the frequency-and-temporal embedding setup, even when that includes the transformer modules. Kinda interesting. I’ll be pretraining it on my dataset over the next week and will likely update this post with a pretrained checkpoint if things go well, though it’s currently 205M parameters, so training requires 12 GB of VRAM. Pretty excited to see how it turns out, though I worry that the probability distributions BERT predicts are where it really gets its power from…