PyTorch gradient differs from JAX


I was given an optimizer that takes as input a function computing the gradient of a loss function.
I implemented three versions of the gradient function:

  1. computing the gradient by hand and implementing it in NumPy
  2. computing the gradient with JAX
  3. computing the gradient with Torch
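For version 1, assuming the NumPy implementation uses the same loss as the Torch snippet below, L(X) = 1/2 ||Y - K X||_F^2, the hand-derived gradient with respect to X is K^T (K X - Y). The question does not show the NumPy code, so this is only a sketch: `loss_grad_numpy` and `finite_difference_grad` are hypothetical names, and the finite-difference check is an addition for verifying the formula.

```python
import numpy as np


def loss(X, Y, K):
    # 1/2 * squared Frobenius norm of the residual Y - K X
    R = Y - K @ X
    return 0.5 * np.sum(R ** 2)


def loss_grad_numpy(X, Y, K):
    # Closed-form gradient of the loss above with respect to X: K^T (K X - Y)
    return K.T @ (K @ X - Y)


def finite_difference_grad(f, X, eps=1e-6):
    # Central differences, one entry of X at a time (fine for small checks)
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G


rng = np.random.default_rng(0)
X = rng.random((5, 5))
Y = rng.random((4, 5))
K = rng.random((4, 5))

analytic = loss_grad_numpy(X, Y, K)
numeric = finite_difference_grad(lambda Z: loss(Z, Y, K), X)
print(np.max(np.abs(analytic - numeric)))  # close to 0 (rounding error only)
```

Because the loss is quadratic in X, central differences are exact up to floating-point rounding, so any large mismatch would point to an error in the hand-derived formula.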

The Torch version looks roughly like this:

import torch
import numpy as np

X = np.random.rand(5, 5)
Y = np.random.rand(4, 5)
K = np.random.rand(4, 5)

alpha = 0.1

def loss_grad(X, Y, K):
    X = torch.tensor(X, requires_grad=True)
    Y = torch.tensor(Y)
    K = torch.tensor(K)

    loss_val = 1/2 * torch.pow(torch.norm(Y-torch.matmul(K,X)),2)

    # backpropagate so that X.grad is populated
    loss_val.backward()
    return X.grad.numpy()

# the function I pass to the optimizer
dfX = lambda x : loss_grad(x, Y, K)
# this would be inside the optimizer
for _ in range(100000):
    X = X - alpha * dfX(X)
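For comparison, version 2 might look roughly like the following in JAX. This is a sketch, not the actual JAX code from the question (which is not shown): `jax_loss_grad` is a hypothetical name, and enabling `jax_enable_x64` is an assumption made so that JAX works in float64 like the tensors `torch.tensor()` creates from NumPy arrays (JAX defaults to float32).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: use float64 so JAX matches the float64 Torch tensors
jax.config.update("jax_enable_x64", True)


def loss(X, Y, K):
    # Same loss as the Torch snippet: 1/2 * ||Y - K X||^2 (Frobenius norm)
    return 0.5 * jnp.linalg.norm(Y - K @ X) ** 2


# jax.grad differentiates with respect to the first argument (X) by default
_loss_grad = jax.grad(loss)


def jax_loss_grad(X, Y, K):
    # Convert inputs to JAX arrays and the result back to NumPy for the optimizer
    return np.asarray(_loss_grad(jnp.asarray(X), jnp.asarray(Y), jnp.asarray(K)))
```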

My problem is that Torch and JAX compute exactly the same value for the loss, but after three iterations the Torch gradient contains NaN values. This does not happen with JAX or with the hand-written NumPy implementation. The example above works fine with Torch, so I do not understand where those NaN values come from.


Your question is a bit confusing about which implementations you have, and which of them match and which don't.

The example you gave works fine for me.
Most likely, your function reaches a point where it is not differentiable, and Torch and JAX behave differently at that point.
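In this particular loss, one candidate for such a point is the square root inside `torch.norm`: the derivative of ||R|| is R / ||R||, which is undefined when the residual R = Y - K X is exactly zero. A minimal sketch of a workaround, assuming the loss is the one from the question, is to sum the squared residuals directly. The value is the same, but autograd never differentiates a square root, and the gradient K^T (K X - Y) is smooth everywhere.

```python
import numpy as np
import torch


def loss_grad(X, Y, K):
    X = torch.tensor(X, requires_grad=True)
    Y = torch.tensor(Y)
    K = torch.tensor(K)

    # 1/2 * ||Y - K X||^2 written as a plain sum of squares, so no
    # square root (singular at zero) appears in the autograd graph
    residual = Y - K @ X
    loss_val = 0.5 * torch.sum(residual ** 2)

    loss_val.backward()
    return X.grad.numpy()
```

This keeps the same signature as the original `loss_grad`, so it can be passed to the optimizer unchanged.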
Can you check at which point your function becomes NaN? You can also enable anomaly detection in PyTorch to see which operation created the NaNs.
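Both suggestions can be sketched as follows, reusing the loss from the question. `loss_grad_checked` and `first_nan_iteration` are hypothetical helpers. Inside `torch.autograd.detect_anomaly()` (or after calling `torch.autograd.set_detect_anomaly(True)` globally), a backward op that returns NaN raises an error naming that op. `first_nan_iteration` replays the optimizer's update and stops at the first NaN gradient so the offending X can be inspected.

```python
import numpy as np
import torch


def loss_grad_checked(X, Y, K):
    # Same computation as loss_grad in the question, with anomaly detection:
    # if a backward function produces NaN, a RuntimeError naming it is raised
    with torch.autograd.detect_anomaly():
        X = torch.tensor(X, requires_grad=True)
        Y = torch.tensor(Y)
        K = torch.tensor(K)
        loss_val = 0.5 * torch.pow(torch.norm(Y - torch.matmul(K, X)), 2)
        loss_val.backward()
    return X.grad.numpy()


def first_nan_iteration(grad_fn, X, alpha=0.1, n_iter=100):
    # Run the same update as the optimizer loop and return the index of the
    # first iteration whose gradient contains NaN (None if that never happens)
    for i in range(n_iter):
        g = grad_fn(X)
        if np.isnan(g).any():
            print("NaN gradient at iteration", i, "for X =", X)
            return i
        X = X - alpha * g
    return None
```

Anomaly detection slows down every backward pass, so it is best used only while debugging, with a small `n_iter`.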