Deep double descent with AdamW

I hope this is an appropriate place to bring this up. I am using PyTorch, but this is more of a question about training methodology. As I mentioned in the title, I am trying to reproduce deep double descent, and I am using AdamW to optimize my weights.

I managed to get to near-zero loss, with the average loss oscillating around 1e-4 to 1e-5, and I am somewhat stuck. Stuck, because I keep running cosine annealing with diminishing returns on the weight norm of my weights, and my network is still outputting garbage. Somewhat, because I noticed that I can keep increasing the weight decay above 1 and the loss still reconverges towards 0 instead of the network breaking down completely.
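To make the setup concrete, here is a stripped-down sketch of the kind of training loop I mean; the model, data, and exact hyperparameters are placeholders rather than my real network:

```python
import torch
import torch.nn as nn

# Placeholder model and data, not my actual network or dataset.
model = nn.Sequential(nn.Linear(32, 512), nn.ReLU(), nn.Linear(512, 1))
x, y = torch.randn(1024, 32), torch.randn(1024, 1)

# AdamW with the weight decay pushed above 1, plus cosine annealing.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
loss_fn = nn.MSELoss()

for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()
    # The training loss keeps reconverging towards ~1e-4 even with weight_decay > 1.
```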

I just want to confirm whether I am understanding this correctly. My thinking is that loss convergence is still possible with weight decay above 1 because, at very small weight magnitudes, the denominator (the square root of the second-moment moving average of the gradient) has become very small, so the moving-average part of the Adam algorithm now dominates the update step. It is as if the moving averages can still update and compensate instead of the step completely blowing up (which is what happens as the denominator tends to 0), but they also move the weights much more than the current gradient and the weight decay term do.
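To spell out what I mean by the denominator and the moving averages, here is my own hand-rolled sketch of a decoupled AdamW step for a single tensor (an illustration of the update rule, not the actual torch.optim.AdamW code):

```python
import torch

def adamw_step(p, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1.5):
    # Exponential moving averages of the gradient and the squared gradient.
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** t)   # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)   # bias-corrected second moment

    # Decoupled weight decay: shrinks p by a factor (1 - lr * weight_decay),
    # independently of the gradient statistics.
    p.mul_(1 - lr * weight_decay)

    # Adaptive update: numerator and denominator both come from moving
    # averages of the gradient, so when the raw gradients get tiny they
    # shrink together and the step stays roughly the same size; eps is
    # what keeps the denominator from actually reaching 0.
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)
    return p, m, v
```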

That would also mean that, for AdamW, the “optimal” weight decay for normalization at small weight magnitudes is constantly moving? With that in mind I did try switching to pure SGD, but there, even with a weight decay of 0.3 (reasonable, at least compared to what I am using now), my loss starts diverging instantly, as in the snippet below. Am I reasoning correctly here?
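For reference, the SGD swap was simply this, with the learning rate again a placeholder and "pure" meaning no momentum:

```python
# Same training loop as above, only the optimizer is swapped out.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.3)
# With this, the loss starts diverging almost immediately.
```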