Actually, I’m trying to do something similar to you, as in here.
I believe what you can do is something like this:
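import torch


class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        # Argument order matches the call site below: myloss(y_pred, y).
        ctx.save_for_backward(y_pred, y)
        # Do whatever you want in here, then return the numerical value.
        return (y_pred - y).pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        yy_pred, yy = ctx.saved_tensors
        # Return one gradient per forward input. For the sum of squared
        # errors, d(loss)/d(y_pred) = 2 * (y_pred - y), scaled by the
        # incoming grad_output (chain rule).
        grad_input = grad_output * 2.0 * (yy_pred - yy)
        # The target y needs no gradient, so return None for it.
        return grad_input, None
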
dtype = torch.float
device = torch.device("cpu")
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
myloss = MyLoss.apply
for t in range(500):
    y_pred = x.mm(w1).mm(w2)
    loss = myloss(y_pred, y)
    print(t, loss.item())
    loss.backward()
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()

In the example above, I pass both the prediction and the target to forward, save them with ctx.save_for_backward, and do some operations on them to compute the cost. In backward, I read the saved tensors back and use them to compute a gradient. It has to be nonzero, because if the gradient is 0, everything upstream of it will be zero too and the weights never change. Here the gradient just carries a factor of 2.0 as an example, but you can do more operations there. I’m still waiting for someone to answer my question to confirm that my implementation is 100% correct.
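In the meantime, one way to check a custom backward like this is torch.autograd.gradcheck, which compares the gradient returned by backward against finite differences. What follows is only a rough sketch, not a definitive test: the class name SquaredErrorLoss and the small 4x3 shapes are made up for illustration, and it assumes the loss is the plain sum of squared errors, whose gradient with respect to the prediction is 2 * (y_pred - y).

```python
import torch
from torch.autograd import gradcheck


class SquaredErrorLoss(torch.autograd.Function):
    """Hypothetical stand-in for MyLoss, used only for this check."""

    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return (y_pred - y).pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Gradient for y_pred; None for the target y.
        return grad_output * 2.0 * (y_pred - y), None


# gradcheck relies on finite differences, so use double precision.
y_pred = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(4, 3, dtype=torch.double)

# Returns True on success and raises an error on a mismatch.
check_passed = gradcheck(SquaredErrorLoss.apply, (y_pred, y), eps=1e-6, atol=1e-4)

# Cross-check against autograd applied to the plain expression.
SquaredErrorLoss.apply(y_pred, y).backward()
custom_grad = y_pred.grad.clone()
y_pred.grad = None
(y_pred - y).pow(2).sum().backward()
reference_grad = y_pred.grad.clone()

print(check_passed, torch.allclose(custom_grad, reference_grad))
```

If the backward returned a placeholder instead of the real derivative, gradcheck would raise an error showing the analytical and numerical Jacobians side by side, which makes it easy to see where they differ.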