Hello! I have this code (a simplified version):
```python
import torch
import torch.nn as nn

grads = {}

def save_grad(name):
    # Returns a hook that stores the gradient flowing into a tensor under `name`.
    def hook(grad):
        grads[name] = grad
    return hook

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)
        y_pred = self.linear2(z)
        return y_pred, z

# model, optimizer and data.trn_dl are created elsewhere (omitted here).
for epoch in range(1000):
    model.train()
    for i, dt in enumerate(data.trn_dl):
        optimizer.zero_grad()
        output = model(dt[0])  # output = (y_pred, z)

        # d(x_hat)/dz, captured by a hook on z
        output[1].register_hook(save_grad('z_x_hat'))
        output[0][0].backward(retain_graph=True)
        z_x_hat = grads['z_x_hat']

        # d(y_hat)/dz, captured the same way
        output[1].register_hook(save_grad('z_y_hat'))
        output[0][1].backward(retain_graph=True)
        z_y_hat = grads['z_y_hat']

        z_x_hat.requires_grad = True
        z_y_hat.requires_grad = True
        loss = abs(torch.sqrt(z_x_hat**2 + z_y_hat**2) - 1)
        print(z_x_hat, z_y_hat, loss)
        loss.backward()
        optimizer.step()
```
So my network has two outputs (x_hat, y_hat), and I want the loss function to be built from the derivatives of these outputs with respect to z, namely |sqrt((dx_hat/dz)^2 + (dy_hat/dz)^2) - 1| (I need this as part of something more complex in the actual project, but this is the part where I am stuck). In this simple case, the derivatives are just the weights from z to x_hat and from z to y_hat (in the real case, each derivative is the product of all the partial derivatives along the path between z and x_hat). Using the code above:
```python
output[1].register_hook(save_grad('z_x_hat'))
output[0][0].backward(retain_graph=True)
z_x_hat = grads['z_x_hat']
```
I am able to get the value of this derivative (and hence the weight), but the problem is that I need `retain_graph=True` to do that, and these extra backward calls mess up my whole backprop: I want to optimize the network by backpropagating only the loss at the end, not these intermediate calls. Can someone help me with this? Thank you!
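A possible alternative to the hooks, offered as a sketch under assumptions rather than a definitive fix: `torch.autograd.grad(..., create_graph=True)` returns dx_hat/dz and dy_hat/dz as tensors that stay attached to the graph, so the final loss can be backpropagated into the weights. Unlike `.backward()`, it also does not add anything to the parameters' `.grad` fields. In the hook version, the two intermediate `.backward()` calls do accumulate d(x_hat)/dW and d(y_hat)/dW into `.grad`. The hook gradients are also computed without `create_graph`, so setting `requires_grad = True` on them only creates new leaf tensors that `loss.backward()` cannot trace back to the weights. The sketch assumes `dt[0]` is a single unbatched sample of shape `(2,)`, which the indexing `output[0][0]` suggests. The helper name `derivative_loss`, the SGD optimizer and its learning rate are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn


# Same toy model as in the question.
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)
        y_pred = self.linear2(z)
        return y_pred, z


def derivative_loss(model, x):
    """Hypothetical helper: |sqrt((dx_hat/dz)^2 + (dy_hat/dz)^2) - 1|.

    Assumes x is one unbatched sample of shape (2,), so y_pred[0] and
    y_pred[1] are scalars. torch.autograd.grad does not touch .grad.
    """
    y_pred, z = model(x)
    # create_graph=True keeps the derivative differentiable w.r.t. the weights;
    # retain_graph then defaults to True, so the second call can reuse the graph.
    (dx_dz,) = torch.autograd.grad(y_pred[0], z, create_graph=True)
    (dy_dz,) = torch.autograd.grad(y_pred[1], z, create_graph=True)
    loss = torch.abs(torch.sqrt(dx_dz ** 2 + dy_dz ** 2) - 1).sum()
    return loss, dx_dz, dy_dz


if __name__ == "__main__":
    torch.manual_seed(0)
    model = SimpleNet()
    # Hypothetical optimizer settings, not taken from the question.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(2)  # stands in for one dt[0] sample
    for epoch in range(100):
        optimizer.zero_grad()
        loss, dx_dz, dy_dz = derivative_loss(model, x)
        loss.backward()  # the only call that writes to .grad
        optimizer.step()
    print(loss.item(), dx_dz.item(), dy_dz.item())
```

In this toy model the returned derivatives equal the two entries of `linear2.weight`, and only `linear2` receives a gradient from the loss. For the deeper real network, the same two `autograd.grad` calls would return the full chain-rule product between z and each output.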