Below I have a simple fixed-point layer that just iterates a function forward for six steps. In the first implementation, FixedPoint1, I backpropagate through the entire forward iteration. In the second implementation, FixedPoint2, I want to backpropagate only one step back and use the gradient from that single-step backprop as the gradient into self.initial. However, my implementation does not actually produce a gradient for self.initial (it stays None). Concretely, I want the gradient of the loss with respect to self.initial to be the same as the gradient of the loss with respect to second_to_last_x. What would be the proper way of doing this?
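To spell out the relationship I'm after: once loss.backward() has run, I would want something like the following to hold (pseudocode only, since second_to_last_x is a local inside forward and fixedpoint is the module instance from the tests further down):

# Desired: the single-step gradient that reaches second_to_last_x
# should also end up as the gradient stored on self.initial.
assert torch.allclose(fixedpoint.initial.grad, second_to_last_x.grad)

Here is the full code: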
import torch
import torch.nn as nn
import torch.nn.functional as F

def visualize_parameters(model):
    # Print the gradient norm of every parameter (or None if it has no grad).
    for n, p in model.named_parameters():
        try:
            if p.grad is None:
                print('{}\t{}'.format(n, None))
            else:
                print('{}\t{}'.format(n, p.grad.data.norm()))
        except Exception:
            print(f'Could not get grad for {n}')
class FixedPoint1(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)
        self.hooks = []

    def forward(self, steps=5):
        # Backpropagation goes through every iteration of self.step.
        x = self.initial
        for i in range(steps):
            x = self.step(x)
        second_to_last_x = x
        new_x = self.step(second_to_last_x)
        return new_x
class FixedPoint2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)
        self.hooks = []

    def forward(self, steps=5):
        x = self.initial
        # Run the iteration without tracking gradients; only the final step
        # below should be part of the autograd graph.
        with torch.no_grad():
            for i in range(steps):
                x = self.step(x)
        second_to_last_x = x.requires_grad_()
        self.hooks.append(second_to_last_x.register_hook(
            lambda grad: print(f'gradient of second to last: {grad.norm()}')))
        new_x = self.step(second_to_last_x)

        def backward_hook(grad):
            if any([hook is not None for hook in self.hooks]):
                for hook in self.hooks:
                    hook.remove()
            # new_grad is the single-step gradient I want to route into self.initial
            new_grad = torch.autograd.grad(new_x, x, grad)[0]
            return grad

        self.hooks.append(self.initial.register_hook(backward_hook))
        return new_x
def test_fixedpoint_1():
    torch.manual_seed(0)
    dim = 10
    target = torch.rand(dim)
    fixedpoint = FixedPoint1(dim)
    loss = F.mse_loss(fixedpoint(), target)
    loss.backward()
    visualize_parameters(fixedpoint)

def test_fixedpoint_2():
    torch.manual_seed(0)
    dim = 10
    target = torch.rand(dim)
    fixedpoint = FixedPoint2(dim)
    loss = F.mse_loss(fixedpoint(), target)
    loss.backward()
    visualize_parameters(fixedpoint)

print('FixedPoint1:')
test_fixedpoint_1()
print('\n')
print('FixedPoint2:')
test_fixedpoint_2()
Running this gives the output:
FixedPoint1:
initial 0.014008551836013794
step.weight 0.26398196816444397
step.bias 0.3846798539161682
FixedPoint2:
gradient of second to last: 0.1674455851316452
initial None ## I want this to be 0.1674455851316452 as well
step.weight 0.2527035176753998
step.bias 0.3748074173927307
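For reference, the closest I've come to expressing what I want is a straight-through-style re-attachment like the sketch below (FixedPoint3 is just my own illustrative name, and I haven't verified that this is correct or idiomatic). The idea is that the re-attached iterate has an identity Jacobian with respect to self.initial, so whatever gradient reaches second_to_last_x should flow into self.initial unchanged. Is something like this the proper approach, or is there a better way (for example, getting the hook version above to actually fire)?

class FixedPoint3(nn.Module):
    # Sketch only: one-step backprop where the gradient w.r.t. self.initial
    # is meant to equal the gradient w.r.t. second_to_last_x.
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)

    def forward(self, steps=5):
        x = self.initial
        with torch.no_grad():
            for i in range(steps):
                x = self.step(x)
        # Re-attach the no-grad iterate to self.initial: the value stays x,
        # but gradients flowing into second_to_last_x pass straight through
        # to self.initial (identity Jacobian).
        second_to_last_x = self.initial + (x - self.initial).detach()
        return self.step(second_to_last_x)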