I was given an optimizer that takes as input a function that computes the gradient of a loss function.
I implemented three versions of the gradient function:
- computing the gradient by hand and implementing it in NumPy
- computing the gradient with JAX (both of these are sketched below)
- computing the gradient with Torch
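For context, the first two versions look roughly like this (simplified sketches, not my exact code; the analytic gradient of 1/2 * ||Y - KX||_F^2 with respect to X is K^T (KX - Y)):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hand-derived gradient: d/dX [ 1/2 * ||Y - K X||_F^2 ] = K^T (K X - Y)
def loss_grad_numpy(X, Y, K):
    return K.T @ (K @ X - Y)

# Same loss written for JAX; jax.grad differentiates it with respect to X
def loss_jax(X, Y, K):
    return 0.5 * jnp.sum((Y - K @ X) ** 2)  # equals 1/2 * ||Y - K X||_F^2

loss_grad_jax = jax.grad(loss_jax)  # gradient w.r.t. the first argument (X)
```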
The Torch version looks roughly like this snippet:
```python
import torch
import numpy as np

X = np.random.rand(5, 5)
Y = np.random.rand(4, 5)
K = np.random.rand(4, 5)
alpha = 0.1

def loss_grad(X, Y, K):
    # wrap the NumPy arrays in tensors; only X needs a gradient
    X = torch.tensor(X, requires_grad=True)
    Y = torch.tensor(Y)
    K = torch.tensor(K)
    # loss = 1/2 * ||Y - K X||_F^2
    loss_val = 1 / 2 * torch.pow(torch.norm(Y - torch.matmul(K, X)), 2)
    print(loss_val)
    loss_val.backward()
    return X.grad.numpy()

# the function I am giving to the optimizer
dfX = lambda x: loss_grad(x, Y, K)

# this would be inside the optimizer
for _ in range(100000):
    X = X - alpha * dfX(X)
```
My problem is that Torch and JAX compute exactly the same value for the loss, but after three iterations I get NaN values in the gradient from Torch. This does not happen with JAX or the hand-written NumPy implementation. The example above works fine with Torch, so I don't really know where the NaN values come from.
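For what it is worth, the per-iteration check I run boils down to something like this (simplified sketch, reusing `loss_grad` from the Torch snippet and the `loss_grad_jax` name from the sketch above):

```python
# simplified sketch of the comparison, not my exact script
g_torch = loss_grad(X, Y, K)                # gradient from Torch autograd
g_jax = np.asarray(loss_grad_jax(X, Y, K))  # gradient from JAX for the same X
print(np.abs(g_torch - g_jax).max())        # agrees at first
print(np.isnan(g_torch).any(), np.isnan(g_jax).any())  # Torch turns NaN after ~3 steps, JAX does not
```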