Vanishing gradients, and how do I fix them?

I believe I am running into vanishing gradients and was hoping someone could help me either fix this or revive the model.

I am using a recreation of AlphaZero’s model, implemented in PyTorch.

This started showing up many training iterations ago, but I didn't catch it early enough. I'm looking to fix it going forward and, if needed, to modify some of the weights now.

I noticed it when I was compiling my model into a TensorRT model via torch_tensorrt. It gave me warnings that my model has weights smaller than float16 can represent, and that they would be converted to the lowest float16 value.
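To see which parameters the float16 warning is likely about, you can cast each one to half precision yourself and count the nonzero fp32 values that become fp16 subnormals or round to exactly zero. This is a minimal sketch, not what torch_tensorrt does internally; the helper name `fp16_underflow_report` and the toy `nn.Linear` are hypothetical stand-ins for the real network.

```python
import torch
import torch.nn as nn

# Smallest positive *normal* float16 value (about 6.1e-05). Smaller magnitudes
# are stored as fp16 subnormals (reduced precision), and magnitudes below
# about 6e-08 round to zero.
FP16_TINY = torch.finfo(torch.float16).tiny

def fp16_underflow_report(model: nn.Module) -> dict:
    """Per parameter: count nonzero fp32 entries that become subnormal
    or exactly zero when cast to float16 (hypothetical helper)."""
    report = {}
    for name, p in model.named_parameters():
        w = p.detach().float()
        nonzero = w != 0
        half = w.half()
        to_zero = (nonzero & (half == 0)).sum().item()
        subnormal = (nonzero & (half != 0) & (half.abs() < FP16_TINY)).sum().item()
        report[name] = {"numel": w.numel(), "subnormal": subnormal, "to_zero": to_zero}
    return report

# Toy example: one weight that underflows to zero, one that becomes subnormal.
model = nn.Linear(4, 2)
with torch.no_grad():
    model.weight.fill_(0.1)
    model.weight[0, 0] = 1e-9   # rounds to 0 in float16
    model.weight[0, 1] = 1e-6   # representable only as an fp16 subnormal
report = fp16_underflow_report(model)
print(report["weight"])  # {'numel': 8, 'subnormal': 1, 'to_zero': 1}
```

On the real model you would pass your network in place of the toy `nn.Linear` and look for layers with large `to_zero` counts.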

On inspecting my weights, I noticed that about 90% or more of the weights in my layers are near 0, hence my issue with vanishing weights.
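A per-layer count makes the "90% near zero" observation reproducible and lets you see whether it is spread across all layers or concentrated in a few. This is a sketch under assumptions: the `1e-6` threshold is arbitrary, 1-D parameters (biases, batch-norm scales) are skipped by choice, and the small `nn.Sequential` is a hypothetical stand-in for the actual network.

```python
import torch
import torch.nn as nn

def near_zero_fraction(model: nn.Module, threshold: float = 1e-6) -> dict:
    """Fraction of entries with |w| < threshold, per weight tensor.
    The default threshold is an assumption; tune it for your model."""
    stats = {}
    for name, p in model.named_parameters():
        if p.dim() < 2:  # skip biases and normalization parameters
            continue
        stats[name] = (p.detach().abs() < threshold).float().mean().item()
    return stats

# Toy model standing in for the real network (hypothetical layers).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
with torch.no_grad():
    model[0].weight.zero_()  # simulate a fully "vanished" layer
fractions = near_zero_fraction(model)
for name, frac in fractions.items():
    print(f"{name}: {frac:.1%} near zero")
```

Logging this once per training iteration would also show when the collapse started.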

From my understanding of vanishing weights, my model shouldn’t have them, but it seems that it does.

  1. Can I fix this going forward? Is anything wrong here?
  2. Can I just set all of my vanished weights to a minimum value to jump-start them, or maybe to a range of small values? I understand this could temporarily mess up the model, but I would hope it could recover after a few iterations.

I recreated my first iteration to see whether the problem appeared early on or only later. When I use the Adam optimizer with a learning rate of 0.001, I get vanishing gradients after the first pass.

When I use SGD with a learning rate of 0.2, I do not get vanishing gradients on the first pass, though I am not sure this will hold up.
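To make the Adam-vs-SGD comparison above more concrete, it may help to log per-layer gradient norms after each backward pass instead of inspecting weights by eye. This is a sketch on a toy network with random data: the architecture, the MSE loss, the data, and the two-pass loop are all assumptions, and `grad_norms`/`one_pass` are hypothetical helpers. Swap in your own model, loss, and a real batch.

```python
import math
import torch
import torch.nn as nn

def grad_norms(model: nn.Module) -> dict:
    """L2 norm of each parameter's gradient; call after loss.backward()."""
    return {n: p.grad.norm().item() for n, p in model.named_parameters()
            if p.grad is not None}

def one_pass(opt_name: str, seed: int = 0) -> dict:
    torch.manual_seed(seed)
    # Toy stand-in for the policy/value network (hypothetical architecture).
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    if opt_name == "adam":
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    else:
        opt = torch.optim.SGD(model.parameters(), lr=0.2)
    x, y = torch.randn(64, 16), torch.randn(64, 1)  # random placeholder data
    for _ in range(2):  # the second pass shows gradients after the first update
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        norms = grad_norms(model)
        opt.step()
    return norms

adam_norms = one_pass("adam")
sgd_norms = one_pass("sgd")
for name in adam_norms:
    print(f"{name}: adam={adam_norms[name]:.3e}  sgd={sgd_norms[name]:.3e}")
```

Using the same seed for both runs means the two optimizers start from identical weights and data, so differences in the logged norms come from the first update alone.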

I initially started with Adam and only recently switched to SGD (I read that SGD generalizes better), which would explain why I now have vanished weights.

So, some additional questions:
Can I just keep training with SGD, and will that recover the vanished weights?
Also, I’m curious why Adam would have caused this in the first place.
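This does not diagnose the model. It only illustrates one well-known mechanical difference between the two optimizers: Adam divides each parameter's update by a running estimate of that parameter's gradient magnitude, so with default settings the very first step moves every parameter with a nonzero gradient by roughly `lr`, however small its gradient is. SGD's step is simply `lr * grad`. Whether this is what drove the weights toward zero here is an open question. The gradient values below are made up, and `first_step_delta` is a hypothetical helper.

```python
import torch

def first_step_delta(opt_cls, grad: torch.Tensor, **opt_kwargs) -> torch.Tensor:
    """Apply one optimizer step to a zero-initialized parameter with a fixed
    gradient and return how far each entry moved (hypothetical helper)."""
    p = torch.zeros_like(grad, requires_grad=True)
    opt = opt_cls([p], **opt_kwargs)
    p.grad = grad.clone()
    opt.step()
    return p.detach().clone()

g = torch.tensor([1e-6, 1.0, -5.0])  # made-up gradients spanning 7 orders of magnitude
adam = first_step_delta(torch.optim.Adam, g, lr=1e-3)
sgd = first_step_delta(torch.optim.SGD, g, lr=0.2)
print(adam)  # every entry moves by roughly 1e-3, regardless of |g|
print(sgd)   # each entry moves by 0.2 * g
```

With bias correction, Adam's first update works out to `lr * g / (|g| + eps)`, which is why the tiny and large gradients produce nearly the same step size.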

Is there an easy function to set all near-zero weights to a range of small numbers?
Something like:
torch.nn.init.uniform_(policy.conv.conv1.weight, -.01, .01)
but applied only to the near-zero weights?
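As far as I know `torch.nn.init` has no masked variant, but the same effect can be had by drawing a full uniform tensor and using `torch.where` to keep it only where the existing weights are near zero. This is a minimal sketch: `reinit_near_zero_` is a hypothetical helper, and the `1e-6` threshold and `[-0.01, 0.01]` range are assumptions taken loosely from the snippet above. Note that if you use Adam or SGD with momentum, the optimizer's state buffers still hold statistics from before the reset.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def reinit_near_zero_(weight: torch.Tensor, threshold: float = 1e-6,
                      low: float = -0.01, high: float = 0.01) -> int:
    """In place: replace entries with |w| < threshold by U(low, high) samples
    and leave all other entries untouched. Returns how many were replaced."""
    mask = weight.abs() < threshold
    fresh = torch.empty_like(weight).uniform_(low, high)
    weight.copy_(torch.where(mask, fresh, weight))
    return int(mask.sum().item())

# Stand-in for policy.conv.conv1 from the question (hypothetical shape).
conv1 = nn.Conv2d(16, 16, 3)
with torch.no_grad():
    conv1.weight[:8].zero_()  # pretend half the filters vanished
before = conv1.weight.detach().clone()
n = reinit_near_zero_(conv1.weight)
print(f"reinitialized {n} of {conv1.weight.numel()} weights")
```

On the real model this would be `reinit_near_zero_(policy.conv.conv1.weight)`, or a loop over `model.named_parameters()` to cover every layer.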