How to debug NaN gradients in a custom Cosine Normalization implementation?

I’ve implemented cosine normalization [1], attempting to match the paper [2].

However, with this code I get NaNs when calling .backward() after a few thousand samples have been presented to the network (the exact count varies). Without CosNorm (either with a stock ResNet or after switching to BatchNorm), the problem seems to go away, although I can't prove a negative: it might just take longer to appear.

My primary question is: how can I reliably investigate issues like this? I registered hooks, but nothing passed to them contained NaNs, so I eventually resorted to randomly disabling or changing parts of the code until I had enough clues to identify a culprit (and even now, all my evidence is circumstantial).
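One standard PyTorch tool for this is autograd anomaly detection: inside `torch.autograd.detect_anomaly()` (or after `torch.autograd.set_detect_anomaly(True)`), the backward pass raises a `RuntimeError` naming the first backward function that returned NaN, and prints the traceback of the forward call that created it. Module hooks only see gradients at module boundaries, so they can miss where a NaN is first produced inside a module. The sketch below is illustrative only: `hand_rolled_cosnorm` is a hypothetical stand-in (the original implementation isn't reproduced here) that computes norms with an explicit `sqrt`, and it is fed an all-zero input to trigger the failure. Anomaly mode slows training noticeably, so it is best enabled only while reproducing the problem.

```python
import torch


def hand_rolled_cosnorm(x, w, eps=1e-5):
    # Hypothetical CosNorm core: cosine between x and w, with norms computed
    # via an explicit sqrt. sqrt's backward is grad / (2 * result), which is
    # 0/0 = NaN when the norm is exactly zero.
    x_norm = (x * x).sum().sqrt()
    w_norm = (w * w).sum().sqrt()
    return (x * w).sum() / (x_norm * w_norm + eps)


def backward_with_anomaly_check(x, w):
    """Run forward + backward under anomaly detection.

    Returns None if backward succeeds, otherwise the error message, which
    names the backward function that first produced a NaN.
    """
    with torch.autograd.detect_anomaly():
        out = hand_rolled_cosnorm(x, w)
        try:
            out.backward()
        except RuntimeError as err:
            return str(err)
    return None


w = torch.tensor([0.6, 0.8])
healthy = torch.tensor([1.0, 2.0], requires_grad=True)
all_zero = torch.zeros(2, requires_grad=True)  # e.g. an input zeroed by ReLU

print(backward_with_anomaly_check(healthy, w))   # None
print(backward_with_anomaly_check(all_zero, w))  # names the failing op, e.g. SqrtBackward0
```

For finer-grained checks, `Tensor.register_hook` on intermediate tensors inside the layer lets you inspect gradients that module-level hooks never see.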

My secondary question is: Any ideas why this normalization would be unstable? I was exclusively wrapping Conv2d instances.
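Since the original implementation isn't reproduced here, the following is a minimal sketch of cosine normalization wrapped around an `nn.Conv2d`, assuming the usual formulation: each output is the cosine between a filter and the input patch it covers, with a small `eps` in the denominator. `CosNormConv2d` and its `eps` argument are hypothetical names; the sketch assumes `groups == 1` and zero padding, and ignores the conv bias. One plausible NaN source in layers like this is an all-zero input patch (common after ReLU or at zero-padded borders): if the patch norm goes through a plain `sqrt`, its backward divides by `2 * sqrt(...)`, giving 0/0 or x/0. Clamping before the `sqrt`, as below, avoids that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosNormConv2d(nn.Module):
    """Hypothetical sketch: cosine-normalized wrapper around an nn.Conv2d.

    Each output is (w . x_patch) / (|w| * |x_patch| + eps). Assumes
    groups == 1 and zero padding; the wrapped conv's bias is ignored.
    """

    def __init__(self, conv: nn.Conv2d, eps: float = 1e-5):
        super().__init__()
        assert conv.groups == 1, "sketch only handles groups == 1"
        assert conv.padding_mode == "zeros", "sketch only handles zero padding"
        self.conv = conv
        self.eps = eps

    def forward(self, x):
        c = self.conv
        # Dot product w . x_patch for every filter and spatial location.
        dot = F.conv2d(x, c.weight, None, c.stride, c.padding, c.dilation)
        # |x_patch|^2: sum of squares over each receptive field, via a ones kernel.
        ones = torch.ones_like(c.weight[:1])
        patch_sq = F.conv2d(x * x, ones, None, c.stride, c.padding, c.dilation)
        # clamp_min keeps sqrt's backward away from division by zero
        # on all-zero patches.
        x_norm = patch_sq.clamp_min(self.eps ** 2).sqrt()
        # |w| per output channel, shaped to broadcast over (N, C_out, H, W).
        w_norm = c.weight.flatten(1).norm(dim=1).view(1, -1, 1, 1)
        return dot / (x_norm * w_norm + self.eps)
```

Usage: `CosNormConv2d(nn.Conv2d(3, 16, 3, padding=1, bias=False))` drops in where the plain conv was.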



If x.norm() or w.norm() is 0, the denominator is essentially just 1e-5, so the gradients flowing through the layer can be scaled up by roughly 1/1e-5 = 1e5.
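To put numbers on that, here is a small pure-Python sketch (a scalar-vector toy, not the actual layer) of the gradient of f(x) = (w · x) / (|w| |x| + eps) with respect to x. `cosnorm_grad_x` is a hypothetical helper, and it uses the subgradient 0 for x / |x| at |x| = 0. With |w| = 1, the gradient is about 1e-5 when x is aligned with w and has unit norm, but about |w| / eps = 1e5 when x is zero; with eps = 1.0 it drops to about 1.

```python
import math


def cosnorm_grad_x(x, w, eps=1e-5):
    """Gradient of f(x) = (w . x) / (|w| |x| + eps) with respect to x.

    Uses the subgradient 0 for x / |x| when |x| == 0.
    """
    dot = sum(a * b for a, b in zip(w, x))
    nx = math.sqrt(sum(a * a for a in x))
    nw = math.sqrt(sum(b * b for b in w))
    d = nw * nx + eps
    # Quotient rule: w / d  -  dot * |w| * (x / |x|) / d**2
    return [wi / d - dot * nw * (xi / nx if nx > 0 else 0.0) / d ** 2
            for xi, wi in zip(x, w)]


def l2(v):
    return math.sqrt(sum(a * a for a in v))


w = [0.6, 0.8]  # |w| = 1
print(l2(cosnorm_grad_x([0.6, 0.8], w)))           # ~1e-5: x aligned with w, |x| = 1
print(l2(cosnorm_grad_x([0.0, 0.0], w)))           # ~1e5 = |w| / eps
print(l2(cosnorm_grad_x([0.0, 0.0], w, eps=1.0)))  # ~1.0 with a 1.0 in the denominator
```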

If that gradient then flows backward through a couple of BatchNorm layers, the NaNs may first appear there rather than in the CosNorm layer itself.
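A rough illustration of the general mechanism (not a model of BatchNorm specifically): float32 tops out around 3.4e38, so a gradient that keeps getting multiplied by a large factor overflows to inf within a handful of layers. Normalization backward passes subtract means of the incoming gradient, so once an inf is present, expressions like inf - inf, 0 * inf or inf / inf produce NaN. `layers_until_overflow` is a hypothetical helper, and the constant per-layer growth factor is a deliberate simplification.

```python
import math

FLT32_MAX = 3.4028234663852886e38  # largest finite float32 value


def layers_until_overflow(growth, start=1.0):
    """Count multiplications by `growth` until `start` exceeds float32's max.

    Models a gradient scaled by roughly the same factor at every layer
    (a deliberate simplification).
    """
    if growth <= 1.0:
        raise ValueError("growth must be > 1 for overflow to happen")
    value, layers = start, 0
    while value <= FLT32_MAX:
        value *= growth
        layers += 1
    return layers


print(layers_until_overflow(1e5))  # 8: about 1e35 still fits, 1e40 does not

# Once a value has overflowed to inf, ordinary arithmetic turns it into NaN.
inf = math.inf
print(inf - inf, 0.0 * inf, inf / inf)  # nan nan nan
```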

I wasn't using BatchNorm alongside CosNorm, but I take the general point that dividing small numbers by other small numbers can be unstable.

Adding an extra +1.0 to the denominator has kept training stable for about 40 epochs so far, though I can't say whether it has significantly affected the layer's ability to normalize. I worry that stacking many such layers will end up pushing the norm of x towards zero.
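One way to see what the +1.0 does to normalization: writing s = |w||x|, the output becomes (w · x) / (s + c) = cos θ · s / (s + c) with c = 1.0. It stays close to the cosine only when s is well above c; for small s it is roughly cos θ · s, i.e. it scales with |x| instead of normalizing it. A minimal sketch for a single unit, where `cosnorm_plus_one` is a hypothetical helper:

```python
def cosnorm_plus_one(cos, x_norm, w_norm, c=1.0):
    """Output of one cosine-normalized unit with a constant c added to the denominator.

    (w . x) / (|w||x| + c) = cos * s / (s + c), where s = |w||x|.
    """
    s = w_norm * x_norm
    return cos * s / (s + c)


for x_norm in [0.01, 0.1, 1.0, 10.0, 100.0]:
    out = cosnorm_plus_one(cos=1.0, x_norm=x_norm, w_norm=1.0)
    print(f"|x| = {x_norm:>6}: output = {out:.4f}")
```

With c = 0 this reduces to the plain cosine; with c = 1.0 a perfectly aligned unit outputs only about 0.0099 at |x| = 0.01, but about 0.99 at |x| = 100.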

I whipped up a hack of a script to see the numerical effects, and it seems stable for a variety of weight scalings and normalization constants, though:

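```python
import random


def simcos(x, layer_count, trials=100):
    # avg_list[i] collects the value after layer i, across all trials
    avg_list = [[] for _ in range(layer_count)]

    for _ in range(trials):
        # one random weight scale per layer, uniform in [1.5, 4.5)
        w = [(random.random() + 0.5) * 3 for _ in range(layer_count)]
        xin = x
        for i in range(layer_count):
            # random.random() stands in for the cosine term
            xout = xin * w[i] * random.random()
            # normalize with the extra +1 in the denominator; the factor of 2
            # rescales the random term, whose mean is 0.5
            xnorm = xout * 2 / (xin * w[i] + 1)
            avg_list[i].append(xnorm)  # record this layer's output
            xin = xnorm

    # mean value after each layer
    return [sum(l) / trials for l in avg_list]


for x in [0.1, 0.5, 0.9, 1.0, 1.5, 2.0, 10.0]:
    # print every 5th layer's average
    print(x, simcos(x, 50)[::5])
```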