Greetings everyone,
I’m trying to create a custom loss function with autograd (so that I can call the backward method on it). I’m using this example from the PyTorch tutorials as a guide:
PyTorch: Defining new autograd functions
I modified the loss function as shown in the code below (I added MyLoss and applied it inside the loop):
import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        #print(grad_output)
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y, y_pred):
        ctx.save_for_backward(y, y_pred)
        return (y_pred - y).pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        yy, yy_pred = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input = 2.0*(yy_pred - yy)
        return grad_input, None


dtype = torch.float
device = torch.device("cpu")

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    relu = MyReLU.apply
    myloss = MyLoss.apply

    y_pred = relu(x.mm(w1)).mm(w2)
    loss = myloss(y_pred, y)
    print(t, loss.item())

    loss.backward()

    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()
When I apply the loss as shown in the snippet above, the results look like this:
0 27354714.0
1 124100200.0
2 1318249856.0
3 75945402368.0
4 3991775354028032.0
5 1.0317169395036933e+30
6 inf
7 nan
...
495 nan
496 nan
497 nan
498 nan
499 nan
The loss stays at ‘nan’ until the end of the loop. However, when I use the code from the tutorial (either with or without autograd), the results look like this:
0 30291468.0
1 25411112.0
2 25554632.0
3 26653730.0
.....
496 8.221517782658339e-05
497 8.073687786236405e-05
498 7.963352254591882e-05
499 7.827148510841653e-05
My two questions are:
- Why are the results different? Am I doing something wrong?
- Out of curiosity, why is “relu = MyReLU.apply” placed inside the loop?
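One way to narrow down the first question may be to compare the hand-written backward against numerical gradients with torch.autograd.gradcheck. The sketch below is only a check under stated assumptions, not a confirmed diagnosis: it uses small double-precision tensors (gradcheck is unreliable in single precision), the shapes are arbitrary, and SquaredErrorLoss is a hypothetical variant written for comparison whose forward takes its arguments in the same order as the call and whose backward scales by grad_output.

```python
import torch
from torch.autograd import gradcheck


class MyLoss(torch.autograd.Function):
    # Same definition as in the post, kept unchanged on purpose.
    @staticmethod
    def forward(ctx, y, y_pred):
        ctx.save_for_backward(y, y_pred)
        return (y_pred - y).pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        yy, yy_pred = ctx.saved_tensors
        grad_input = 2.0 * (yy_pred - yy)
        return grad_input, None


class SquaredErrorLoss(torch.autograd.Function):
    # Hypothetical variant (not from the tutorial): the argument order
    # matches the call site, and the gradient is scaled by grad_output.
    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return (y_pred - y).pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        grad_y_pred = grad_output * 2.0 * (y_pred - y)
        # No gradient is returned for the target y.
        return grad_y_pred, None


torch.manual_seed(0)
# Small, arbitrary shapes; only y_pred requires a gradient here.
y_pred = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(4, 3, dtype=torch.double)

# Both functions are called the same way as in the training loop: (y_pred, y).
original_ok = gradcheck(MyLoss.apply, (y_pred, y), raise_exception=False)
variant_ok = gradcheck(SquaredErrorLoss.apply, (y_pred, y), raise_exception=False)

print("MyLoss as posted passes gradcheck:", original_ok)
print("SquaredErrorLoss variant passes gradcheck:", variant_ok)
```

If the posted MyLoss fails the check while the variant passes, the first things to look at would be the order of y and y_pred in forward and backward relative to the call myloss(y_pred, y), and the fact that grad_output is overwritten rather than used.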
Thank you in advance!