Matrix multiplication is resulting in NaN values during backpropagation

I am trying to make a simple Taylor series layer for my neural network but am unable to test it out because the weights become NaNs on the first backward pass.

Here is the code:

import math
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # used below to move the exponent tensor


class Maclaurin(nn.Module):
    """ Maclaurin Series Layer  First Draft """
    def __init__(self):
        super().__init__()
        weights = torch.Tensor(1, 30)
        bias = torch.Tensor([64])
        self.bias = nn.Parameter(bias)
        scal = torch.arange(0, 30)  # powers of taylor series
        self.scal = scal.to(device)

        # initialize weights & biases
        nn.init.kaiming_uniform_(weights, a=math.sqrt(5)) # weight init
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)  # bias init
        weights = torch.transpose(weights, 0, 1)
        self.weights = nn.Parameter(weights)

    def forward(self, x):
        xr = x.repeat(1, 30)  # extend values to create Maclaurin series for each point
        xr = torch.pow(xr, self.scal)  # raise each term in series to proper power
        wx = torch.mm(xr, self.weights)  # multiply columns by weights
            
        return torch.add(wx, self.bias)  # w times x + b


class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_out):
        super(NeuralNet, self).__init__()
        self.taylor = Taylor()
        self.taylor.cuda()

    def forward(self, x):
        out = self.taylor(x)
        return out

I checked all of the sizes of the tensors for consistency here:

                x          powers   weights
initial         [64, 1]    [30]     [30, 1]
x.repeat        [64, 30]   [30]     [30, 1]
powers          [64, 30]   [30]     [30, 1]
x * weights     [64, 1]    [30]     [30, 1]
sum rows        [64]       [30]     [30, 1]

Here is the output when I use torch.autograd.set_detect_anomaly(True) (I removed the file info).

wx = torch.mm(xr, self.weights)
 (function _print_stack)
Traceback (most recent call last):
  File "", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in <module>
    runfile("')
  File "", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "", line 113, in <module>
    loss.backward()         # backpropagation, compute gradients
  File "", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'MmBackward' returned nan values in its 1th output.

Checking for infinite values in the gradients gave me this:

taylor.bias tensor(True, device='cuda:0')
taylor.weights tensor(False, device='cuda:0')
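
Roughly like this (a sketch, not my exact code; model stands for the NeuralNet instance):

# after loss.backward(), look for non-finite gradients per parameter
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, torch.isinf(param.grad).any())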

So the check registers an infinite gradient for the bias on the first iteration, yet the anomaly reported during backpropagation points at the torch.mm call, so it seems like the problem lies there.

In some previous iterations of the code, I got rid of all of these problems by leaving the weights as a row vector and multiplying them by a diagonal matrix before doing the matrix multiplication, but then the network stopped learning the weights while still learning the bias normally.

Notes:

  • I already tried a very small learning rate. I don’t think this is the problem.
  • The bias and half of the weights become NaNs by the second iteration, and all of the weights are NaNs by the third.
  • Most loss functions seem to have this problem, but some, such as torch.nn.SmoothL1Loss(), do not (as long as the number of terms in the series is less than 40), so it would be interesting to see whether the loss function plays a role (see the sketch after these notes).
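
For illustration, here is a minimal comparison of a quadratic loss against torch.nn.SmoothL1Loss() on a single huge error (illustrative values, not from my actual training run; MSELoss just stands in for a loss that does blow up):

import torch
import torch.nn as nn

pred = torch.tensor([1e20], requires_grad=True)   # huge prediction, like the early steep curves
target = torch.zeros(1)

loss = nn.MSELoss()(pred, target)                 # (1e20)**2 = 1e40 overflows float32 -> inf loss
loss.backward()
print(loss, pred.grad)                            # tensor(inf), gradient ~2e20 (scales with the error)

pred.grad = None
loss = nn.SmoothL1Loss()(pred, target)            # linear beyond beta, so the loss stays finite
loss.backward()
print(loss, pred.grad)                            # loss ~1e20, gradient clamped to +/-1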

Don't use torch.Tensor to initialize the parameters, as its usage is deprecated and undocumented.
Depending on what input you are passing to Tensor, you might get unexpected results, as seen here:

# initializes the tensor with the value 64 as a FloatTensor
x = torch.Tensor([64])
print(x)
> tensor([64.])

# creates an uninitialized FloatTensor with the shape 64
x = torch.Tensor(64)
print(x)
> tensor([-2.2953e-03,  3.0882e-41, -3.0420e-03,  4.5762e-41,  8.9683e-44,
           0.0000e+00,  1.1210e-43,  0.0000e+00, -2.2953e-03,  3.0882e-41,
          -2.2955e-03,  3.0882e-41,  1.4013e-45,  3.0882e-41,  2.8026e-44,
           2.1019e-44,  2.8026e-44,  3.7835e-44,  2.8026e-45,  0.0000e+00,
          -2.2955e-03,  3.0882e-41, -2.2955e-03,  3.0882e-41,  2.8026e-44,
           0.0000e+00,  2.4803e-43,  0.0000e+00, -2.2962e-03,  3.0882e-41,
          -3.0420e-03,  4.5762e-41,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  2.8026e-44,  0.0000e+00, -4.1769e-05,  3.0882e-41,
           0.0000e+00,  0.0000e+00, -7.4630e-04,  4.5762e-41,  2.8026e-44,
           4.3440e-44,  1.3593e-43,  0.0000e+00, -2.2961e-03,  3.0882e-41,
          -2.2953e-03,  3.0882e-41,  2.3822e-44,  4.5762e-41, -2.2955e-03,
           3.0882e-41, -2.2955e-03,  3.0882e-41, -2.2955e-03,  3.0882e-41,
           2.8026e-44,  2.1019e-44,  4.6243e-44,  0.0000e+00])

Use the factory methods instead (e.g. torch.randn, torch.empty, torch.tensor).
I don’t know if you want to initialize the self.bias using a single value, but this might be the issue in your code.
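
E.g. a sketch of the same initialization using factory methods (shapes follow your posted code; whether the bias should really be a single value is up to you):

import math
import torch
import torch.nn as nn

weights = torch.empty(1, 30)                       # uninitialized, filled by the kaiming init below
nn.init.kaiming_uniform_(weights, a=math.sqrt(5))

fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
bound = 1 / math.sqrt(fan_in)
bias = torch.empty(1)                              # a single bias value, not an uninitialized tensor of shape 64
nn.init.uniform_(bias, -bound, bound)

weights = nn.Parameter(weights.t())                # [30, 1], as used in the posted forward()
bias = nn.Parameter(bias)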

Thank you for responding!
That fixed three other problems I did not know I had, but the original problem remains. Is it possible that the loss is too large and that is why this is happening?

I'm not sure whether I understand this correctly, but basically, once you have inf, you are most likely in trouble:
When you get an inf bias, one might expect an inf gradient to also propagate to the earlier operations (because for a + b, both a and b get the same gradient as the sum).
But once you have inf and matrix multiplications with varying signs, you very soon get +inf and -inf, which gives you NaN in further additions (such as the reductions inside a matrix multiplication).
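
A tiny example of that mechanism (just an illustration):

import torch

a = torch.tensor([[float('inf'), float('-inf')]])  # +inf and -inf after an overflow
b = torch.ones(2, 1)
print(a @ b)                                       # inf + (-inf) in the reduction -> tensor([[nan]])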

The other bit (you show MacLaurin and use Taylor, but hey) is that 30 elements of a Taylor series is… shall I say ambitious? Quite likely, this has the potential to be tricky w.r.t. avoiding inf and nan.
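
To see why, x**29 alone already leaves float32 range for fairly moderate inputs (illustrative numbers, not from the post):

import torch

x = torch.full((1, 1), 25.0)         # a moderate input value
xr = x.repeat(1, 30)
powers = torch.arange(0, 30)
print(torch.pow(xr, powers).max())   # 25**29 ~ 3e40 overflows float32 -> inf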

Best regards

Thomas

Haha, yeah, I guess I copied the code while I was in the middle of making changes. From what I can understand, though, you are absolutely right. The number of terms I was using made the series start out with some pretty steep curves, so the loss was astronomical. I fixed the problem by using fewer terms in the series and by initializing the parameters in a way that reduces the initial loss.
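
Roughly along these lines (a sketch of the idea, not my exact code): fewer terms, and initial weights shrunk by k! so the high-order terms start out small.

import math
import torch
import torch.nn as nn

n_terms = 10                                   # fewer terms than the original 30
weights = torch.empty(n_terms, 1)
nn.init.kaiming_uniform_(weights, a=math.sqrt(5))

# scale down high-order coefficients, mirroring the 1/k! factors of a real Maclaurin series
factorials = torch.tensor([math.factorial(k) for k in range(n_terms)], dtype=torch.float32)
weights = nn.Parameter(weights / factorials.unsqueeze(1))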

Thank you

Ethan