Using the gradients in a loss function

Hi.
Say I know both what the output of my network should be and what the gradients ought to be for each training example.
Is it possible to train a network in pytorch in which the loss depends on both the output of the network and the gradients of the network w.r.t. the network input?

```python
import torch
from torch.autograd import Variable

inp = Variable(torch.rand(3, 4), requires_grad=True)
W = Variable(torch.rand(4, 4), requires_grad=True)
yreal = Variable(torch.rand(3, 4), requires_grad=False)
gradsreal = Variable(torch.rand(3, 4), requires_grad=False)  # target gradients, no grad needed

ypred = torch.matmul(inp, W)
ypred.backward(torch.ones(ypred.shape))
gradspred = inp.grad  # not tracked by autograd, so the loss cannot backprop through it

loss = torch.mean((yreal - ypred) ** 2 + (gradspred - gradsreal) ** 2)
loss.backward()
```

This won’t work, as inp.grad is volatile. Also, would I have to zero all gradients after calculating gradspred?


To make it work without much change, add create_graph=True, retain_graph=True to the backward call.
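A minimal sketch of that suggestion, assuming a recent PyTorch where plain tensors with requires_grad replace Variable. It also addresses the zeroing question: the first backward() accumulates into W.grad (and writes inp.grad), so both are cleared before the real loss.backward(). Recent PyTorch versions may emit a UserWarning about a reference cycle when backward() is called with create_graph=True.

```python
import torch

torch.manual_seed(0)

# Same toy setup as in the question, written with plain tensors
inp = torch.rand(3, 4, requires_grad=True)
W = torch.rand(4, 4, requires_grad=True)
yreal = torch.rand(3, 4)
gradsreal = torch.rand(3, 4)

ypred = torch.matmul(inp, W)
# create_graph=True records the backward pass so inp.grad is itself differentiable;
# retain_graph=True keeps the forward graph of ypred alive for the second backward
ypred.backward(torch.ones_like(ypred), create_graph=True, retain_graph=True)
gradspred = inp.grad

# The first backward also filled W.grad and inp.grad; clear them so the
# update below only contains the gradient of the actual loss
W.grad = None
inp.grad = None

loss = torch.mean((yreal - ypred) ** 2 + (gradspred - gradsreal) ** 2)
loss.backward()  # W.grad now includes contributions from both terms
```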

However, a more efficient way is:

```python
from torch import autograd

gradspred, = autograd.grad(ypred, inp,
                           grad_outputs=ypred.data.new(ypred.shape).fill_(1),
                           create_graph=True)
loss = ...  # same loss as in the question
loss.backward()
```
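A self-contained sketch of the autograd.grad approach inside a small training loop, assuming a recent PyTorch. Here torch.ones_like(ypred) stands in for ypred.data.new(ypred.shape).fill_(1) (both give a tensor of ones matching ypred), and the SGD optimizer, learning rate, and step count are arbitrary choices for illustration. Because autograd.grad returns the gradient instead of writing it into .grad, only the usual per-step zero_grad() is needed.

```python
import torch
from torch import autograd

torch.manual_seed(0)

inp = torch.rand(3, 4, requires_grad=True)
W = torch.rand(4, 4, requires_grad=True)
yreal = torch.rand(3, 4)
gradsreal = torch.rand(3, 4)

optimizer = torch.optim.SGD([W], lr=0.05)
losses = []
for step in range(20):
    optimizer.zero_grad()
    ypred = torch.matmul(inp, W)
    # Returns a one-element tuple; create_graph=True makes gradspred differentiable
    gradspred, = autograd.grad(ypred, inp,
                               grad_outputs=torch.ones_like(ypred),
                               create_graph=True)
    loss = torch.mean((yreal - ypred) ** 2 + (gradspred - gradsreal) ** 2)
    loss.backward()  # also accumulates into inp.grad, which is unused here
    optimizer.step()
    losses.append(loss.item())
```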

Great! Worked like a charm :smile:

Hello! Thank you for your answer. I followed your instructions but did not understand what the code “grad_outputs=ypred.data.new(ypred.shape).fill_(1)” does. Can you please explain it or point me to an explanation?

Take a look at the thread 'Assign manual assigned "grad_output"'.
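As a rough illustration (a sketch, not taken from the linked thread): for a non-scalar output y, autograd.grad(y, x, grad_outputs=v) computes the vector-Jacobian product, i.e. the gradient of sum(v * y) with respect to x. The old idiom ypred.data.new(ypred.shape).fill_(1) just builds a tensor of ones with ypred's shape, dtype, and device, equivalent to torch.ones_like(ypred), so the result equals the gradient of ypred.sum(). The tensor v below is a hypothetical weighting chosen for the demonstration.

```python
import torch
from torch import autograd

torch.manual_seed(0)
inp = torch.rand(3, 4, requires_grad=True)
W = torch.rand(4, 4)
ypred = inp @ W

# Old idiom from the answer vs. the modern equivalent: both are all-ones tensors
ones_old = ypred.data.new(ypred.shape).fill_(1)
ones_new = torch.ones_like(ypred)

# With v = ones, the vector-Jacobian product equals the gradient of ypred.sum()
g_ones, = autograd.grad(ypred, inp, grad_outputs=ones_new, retain_graph=True)
g_sum, = autograd.grad(ypred.sum(), inp, retain_graph=True)

# A non-uniform v weights each output element; for ypred = inp @ W this is v @ W.T
v = torch.rand(3, 4)
g_v, = autograd.grad(ypred, inp, grad_outputs=v)
```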