Hey,
I was given an optimizer that takes as input a function computing the gradient of a loss function.
I implemented three versions of the gradient function:
- computing the gradient by hand and implementing it in NumPy
- computing the gradient with JAX
- computing the gradient with Torch
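For reference, here is a rough sketch of what the hand-derived version could look like. It assumes the loss is 1/2 * ||Y - K X||^2 with the Frobenius norm (the same expression as in the Torch snippet), for which the gradient with respect to X is K^T (K X - Y). The names `loss` and `loss_grad_numpy` are made up for this illustration, and the finite-difference check at the end is only a sanity test.

```python
import numpy as np

def loss(X, Y, K):
    """0.5 * ||Y - K X||_F^2, written without a square root."""
    R = Y - K @ X
    return 0.5 * np.sum(R * R)

def loss_grad_numpy(X, Y, K):
    """Hand-derived gradient of the loss with respect to X: K^T (K X - Y)."""
    return K.T @ (K @ X - Y)

rng = np.random.default_rng(0)
X = rng.random((5, 5))
Y = rng.random((4, 5))
K = rng.random((4, 5))

# Central finite-difference check of a single entry of the gradient
eps = 1e-6
E = np.zeros_like(X)
E[1, 2] = 1.0
fd = (loss(X + eps * E, Y, K) - loss(X - eps * E, Y, K)) / (2 * eps)
fd_error = abs(fd - loss_grad_numpy(X, Y, K)[1, 2])
print(fd_error)
```

Because the loss is quadratic in X, the central difference should agree with the analytic gradient up to rounding error.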
The program for Torch looks a bit like this snippet here:
import torch
import numpy as np

# Example data: K @ X has shape (4, 5), matching Y
X = np.random.rand(5, 5)
Y = np.random.rand(4, 5)
K = np.random.rand(4, 5)
alpha = 0.1  # step size

def loss_grad(X, Y, K):
    # Only X needs a gradient; Y and K are treated as constants
    X = torch.tensor(X, requires_grad=True)
    Y = torch.tensor(Y)
    K = torch.tensor(K)
    loss_val = 1/2 * torch.pow(torch.norm(Y - torch.matmul(K, X)), 2)
    print(loss_val)
    loss_val.backward()
    return X.grad.numpy()

# the function I am giving to the optimizer
dfX = lambda x: loss_grad(x, Y, K)

# this would be inside the optimizer
for _ in range(100000):
    X = X - alpha * dfX(X)
My problem is that Torch and JAX compute exactly the same value for the loss function, but after three iterations I get NaN values in my gradient with Torch. This does not happen with JAX or with the NumPy implementation. The snippet above runs fine with Torch, so it does not reproduce the problem, and I do not really know why I get those NaN values in my actual program.
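To narrow down where the NaN values first show up, a small wrapper around the gradient function might help. The sketch below is only an illustration: `checked` is a hypothetical helper, not part of the optimizer, and `toy_grad` is a stand-in that deliberately produces NaN (0/0) so the wrapper has something to catch. It works on NumPy arrays, so it could wrap any of the three implementations.

```python
import numpy as np

def checked(grad_fn):
    """Wrap a gradient function so it raises as soon as it returns NaN or inf."""
    calls = 0

    def wrapper(x):
        nonlocal calls
        calls += 1
        g = np.asarray(grad_fn(x))
        if not np.all(np.isfinite(g)):
            input_ok = bool(np.all(np.isfinite(x)))
            raise FloatingPointError(
                f"non-finite gradient on call {calls}; input finite: {input_ok}"
            )
        return g

    return wrapper

# Stand-in gradient used only for this demo: returns NaN when x is all zeros.
def toy_grad(x):
    with np.errstate(invalid="ignore", divide="ignore"):
        return x / np.linalg.norm(x)

safe_grad = checked(toy_grad)
safe_grad(np.ones(3))  # first call: finite, passes through unchanged

try:
    safe_grad(np.zeros(3))  # second call: 0/0 gives NaN, so the wrapper raises
except FloatingPointError as err:
    message = str(err)
    print(message)
```

With the Torch version this would be used as `dfX = checked(lambda x: loss_grad(x, Y, K))`. The error message reports whether the input was still finite, which separates a NaN created inside the gradient computation from one that was already present in X.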