Help training Titans+MIRAS: model learns to cheat the loss

Hi, I’m training a small Titans-MIRAS model in my own setup as a hobby, and I’m hitting some walls. Training should be very similar to that of any small transformer model, and I’m following what I find online.
The model has 270M parameters: 16 layers, an embedding size of 1024, and 16 heads, plus some extra parameters for the neural memory gates.
I’m using portions of MiniPile, WikiText and Dolly as datasets. Due to my limited hardware (a humble RTX 3060 with 12 GB of memory), I’m starting with a max sequence length of 512 and a window size of 256. This means the model itself only has a 256-token context, and the neural memory is trained to recall the information from the previous 256 tokens, which still uses as much training memory as a 512-token context. Batch size is 8 with gradient accumulation of 8, so 64 examples per step. The starting LR is around 5e-4, which I took from some GPT-2 training runs. Because of the cosine decay, the LR stayed close to that value for almost the whole run, until the loss reached 2 and I stopped (for the reasons explained below). I’m also using a weight decay of 0.1 with AdamW8bit, and gradient clipping at 1.0. Warmup was about 300 steps.
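
To make the schedule concrete, here is a minimal plain-Python sketch of linear warmup followed by cosine decay, using the numbers above (peak 5e-4, 300 warmup steps). The planned run length (TOTAL_STEPS = 20,000) and min_lr = 0 are hypothetical, since the post doesn't state them; with a long planned run, the LR at step 3500 is still above 90% of the peak.

```python
import math

def lr_at(step, total_steps, peak_lr=5e-4, warmup_steps=300, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        # Linear ramp: reaches peak_lr exactly at the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# HYPOTHETICAL planned run length; the post doesn't say how many steps were scheduled.
TOTAL_STEPS = 20_000

for step in (0, 299, 1000, 2650, 3500, TOTAL_STEPS):
    lr = lr_at(step, TOTAL_STEPS)
    print(f"step {step:>6}: lr = {lr:.2e} ({lr / 5e-4:.0%} of peak)")
```

With a much shorter planned run (say total_steps=4000), the same function gives well under 10% of the peak at step 3500, so how "high" the LR stays depends entirely on the scheduled length, not on the cosine shape itself.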

The problem is that the model learns to repeat a lot of words right when it’s starting to understand the links between words, after learning basic grammar rules. The loss actually goes down to between 1 and 3, which would usually mean the model generates real sentences close to the training data, but it doesn’t. The quality of the output does not improve as the loss goes from 4 to 2. AFAIK, stuttering is normal when the loss is above 3 or 4, not at 2. To me this means the model is learning to cheat the loss.
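
To put these loss values in perspective: mean cross-entropy measured in nats (natural log, which is what PyTorch's `cross_entropy` uses) converts to perplexity via exp(loss). This sketch assumes the logged loss is the mean per-token cross-entropy over the predicted tokens.

```python
import math

def perplexity(mean_ce_loss):
    """Perplexity from mean token-level cross-entropy measured in nats (natural log)."""
    return math.exp(mean_ce_loss)

for loss in (4.0, 3.0, 2.0, 1.0):
    print(f"loss {loss:.1f} -> perplexity {perplexity(loss):.1f}")
```

This prints perplexities of 54.6, 20.1, 7.4 and 2.7. A loss of 2 means the model is, on average, about as uncertain as picking uniformly among roughly 7 tokens, which is hard to reconcile with the heavily repetitive samples shown here.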

I tried increasing the batch size (via gradient accumulation) to 256, but it doesn’t help; it only slows things down. I also tried a lower LR, with the same result.
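
For reference, here is the arithmetic behind the two batch sizes, assuming a micro-batch of 8 and that every sequence is a full 512 tokens (padding would lower the real token count). `effective_batch` is a hypothetical helper for illustration.

```python
MICRO_BATCH = 8   # sequences per forward/backward pass
SEQ_LEN = 512     # tokens per training sequence (assumes no padding)

def effective_batch(accum_steps, micro_batch=MICRO_BATCH, seq_len=SEQ_LEN):
    """Return (sequences, tokens) per optimizer step with gradient accumulation."""
    sequences = micro_batch * accum_steps
    return sequences, sequences * seq_len

print(effective_batch(8))   # original run: 64 sequences per step
print(effective_batch(32))  # larger run: 256 sequences per step
```

Under the same assumptions, 3500 steps at 64 sequences per step is about 115M tokens. With accumulation, each micro-batch loss is normally divided by the number of accumulation steps before `backward()`, so the accumulated gradient matches one large batch.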
Some example generations:
Step 3500 (batch size 64), where I stopped training because the loss was far too low for the quality of the generations:
Loss around 2 (oscillates between 1 and 3):
<start_of_turn>user
Explain AI.<end_of_turn>
<start_of_turn>model
Mon Sc interaction which addition that combined changes that places and times only twice too too anywhere fewer too too yet too yet needs needs others others ways but they … nevertheless might typically though instead they instead likely when that claim that that may members may the that the that the that competitor that and the that the that the the the the

As you can see, there are a lot of repetitions. The model is cheating.
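
Rather than judging the stutter by eye, a crude repetition score can be logged for a fixed prompt at each checkpoint and compared against the loss curve. The sketch below is a hypothetical helper (`repetition_stats`), not part of any library, and it splits on whitespace instead of using the model's tokenizer, which is a simplification.

```python
def repetition_stats(text, n=2):
    """Crude stutter metrics on a whitespace-tokenized generation."""
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    # Share of n-grams that are unique: lower means more repetition.
    distinct_n = len(set(ngrams)) / len(ngrams) if ngrams else 0.0
    # Longest run of the same word repeated back to back.
    max_run = run = 1 if words else 0
    for prev, cur in zip(words, words[1:]):
        run = run + 1 if cur == prev else 1
        max_run = max(max_run, run)
    return {"distinct_n": distinct_n, "max_run": max_run}

print(repetition_stats("that the that the that the the the the"))
```

On that tail of the step 3500 sample, distinct-2 is 0.375 (3 unique bigrams out of 8) and the longest single-word run is 4. Tracking these per checkpoint would show whether the improvement around step 2650 and the later relapse line up with the loss.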
Another generation, at step 2650 (loss between 2 and 4.5, averaging around 3):

<start_of_turn>user
Explain AI.<end_of_turn>
<start_of_turn>model
Banking 2 and NisioWiFias’s genes globally Developing AI skills to make games difficult to predict. KuneOSOSOS can be able to the role of the role of the Blue Blue B in Baye in Rotors. With multi-competational Agents that results in players from

Here the model actually understood the topic: as expected, it generated random sentences, but they were still related to AI.

Looking at the generations during training, the model repeated a lot of tokens both before and after step 2650, but around step 2650 it was getting better; then it started to stutter again. What’s happening?