How to replace the gradient of the initial tensor using backward hooks?

Below I have a simple fixed-point layer that iterates a function forward for six steps. In the first implementation, FixedPoint1, I backpropagate through the entire forward iteration. In the second implementation, FixedPoint2, I want to backpropagate only one step, and use the gradient from that single-step backprop as the gradient for self.initial. However, my implementation does not actually update the gradient of self.initial. Concretely, I want the gradient of the loss with respect to self.initial to be the same as the gradient of the loss with respect to second_to_last_x. What would be the proper way of doing this?
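
For reference, the hook behaviour I'm relying on (as I understand the Tensor.register_hook docs): if a hook returns a tensor, autograd uses that tensor in place of the gradient it would otherwise pass along, and for a leaf tensor that replacement is what ends up in .grad. A minimal, self-contained check of just that mechanism (the values here are only for illustration):

import torch

a = torch.ones(3, requires_grad=True)
loss = (2 * a).sum()
# returning a tensor from the hook replaces the gradient flowing into a
a.register_hook(lambda grad: torch.full_like(grad, 7.0))
loss.backward()
print(a.grad)  # tensor([7., 7., 7.]) instead of tensor([2., 2., 2.])

Here is my full code: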

import torch
import torch.nn as nn
import torch.nn.functional as F

def visualize_parameters(model):
    for n, p in model.named_parameters():
        try:
            if p.grad is None:
                print('{}\t{}'.format(n, None))
            else:
                print('{}\t{}'.format(n, p.grad.data.norm()))
        except:
            print(f'Could not get grad for {n}')

class FixedPoint1(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)
        self.hooks = []

    def forward(self, steps=5):
        x = self.initial
        for i in range(steps):
            x = self.step(x)
        second_to_last_x = x
        new_x = self.step(second_to_last_x)
        return new_x

class FixedPoint2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)
        self.hooks = []

    def forward(self, steps=5):
        x = self.initial

        # run the first `steps` iterations without tracking gradients
        with torch.no_grad():
            for i in range(steps):
                x = self.step(x)

        # x is now detached from the graph; turn gradient tracking back on here
        second_to_last_x = x.requires_grad_()
        self.hooks.append(second_to_last_x.register_hook(lambda grad: print(f'gradient of second to last: {grad.norm()}')))
        new_x = self.step(second_to_last_x)

        def backward_hook(grad):
            # intended: compute the one-step gradient w.r.t. second_to_last_x
            # and use it as the gradient for self.initial
            if any([hook is not None for hook in self.hooks]):
                for hook in self.hooks:
                    hook.remove()
            new_grad = torch.autograd.grad(new_x, x, grad)[0]
            return grad

        self.hooks.append(self.initial.register_hook(backward_hook))
        return new_x

def test_fixedpoint_1():
    torch.manual_seed(0)
    dim = 10
    target = torch.rand(dim)
    fixedpoint = FixedPoint1(dim)

    loss = F.mse_loss(fixedpoint(), target)
    loss.backward()
    visualize_parameters(fixedpoint)

def test_fixedpoint_2():
    torch.manual_seed(0)
    dim = 10
    target = torch.rand(dim)
    fixedpoint = FixedPoint2(dim)

    loss = F.mse_loss(fixedpoint(), target)
    loss.backward()
    visualize_parameters(fixedpoint)

print('FixedPoint1:')
test_fixedpoint_1()
print('\n')
print('FixedPoint2:')
test_fixedpoint_2()

Running this gives the output:

FixedPoint1:
initial	0.014008551836013794
step.weight	0.26398196816444397
step.bias	0.3846798539161682

FixedPoint2:
gradient of second to last: 0.1674455851316452
initial	None    ## I want this to be 0.1674455851316452 as well
step.weight	0.2527035176753998
step.bias	0.3748074173927307

@mbchang I wanted to understand the input data and model architecture first:

  • You have a parameter initial with which you start

  • You run the Linear steps on x inside a with torch.no_grad() block. This detaches x from the graph, so it loses its connection to self.initial (see the small check after this list)

  • Because of this, gradients do not even flow back to initial
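
A quick way to see this, with a toy linear step (just an illustration, not your model):

import torch
import torch.nn as nn

step = nn.Linear(4, 4)
initial = nn.Parameter(torch.rand(4))

with torch.no_grad():
    x = step(initial)

# x was produced inside no_grad: no grad_fn, no path back to initial
print(x.requires_grad)     # False
print(x.grad_fn is None)   # True

y = step(initial)          # same op outside no_grad
print(y.requires_grad)     # True
print(y.grad_fn is None)   # False: y is still connected to initial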

One of the solutions from my side is:

import torch
import torch.nn as nn
import torch.nn.functional as F

def visualize_parameters(model):
    for n, p in model.named_parameters():
        try:
            if p.grad is None:
                print('{}\t{}'.format(n, None))
            else:
                print('{}\t{}'.format(n, p.grad.data.norm()))
        except:
            print(f'Could not get grad for {n}')
    
class FixedPoint2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)
        self.hooks = None   # will hold the gradient of second_to_last_x

    def assign_grad(self, grad):
        # called from the hook on second_to_last_x: stash its gradient
        self.hooks = grad
        return grad
        
    def forward(self, steps=5):
        x = self.initial
        # torch.no_grad() removed so the iterations stay in the graph
        # and backward() can reach self.initial (and fire its hook)
        for i in range(steps):
            x = self.step(x)

        second_to_last_x = x.requires_grad_()
        # stash the gradient of second_to_last_x during backward
        second_last_hook = second_to_last_x.register_hook(lambda grad: self.assign_grad(grad))
        new_x = self.step(second_to_last_x)
        return new_x

def adjust_initial_grad(grad):
    # replace the gradient flowing into initial with the stored
    # gradient of second_to_last_x
    return fixedpoint.hooks
    
grads = None
torch.manual_seed(0)
dim = 10
target = torch.rand(dim)
fixedpoint = FixedPoint2(dim)
initial_backward_hook = fixedpoint.initial.register_hook(lambda x: adjust_initial_grad(x))
loss = F.mse_loss(fixedpoint(), target)
loss.backward()
visualize_parameters(fixedpoint)
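
With this change, the backward pass reaches initial, its hook fires, and what ends up in initial.grad is exactly the gradient that was stored for second_to_last_x.

If you specifically want to keep the torch.no_grad() block (so the earlier iterations are never tracked or backpropagated through), a hook-free variation that should give the same gradient is to reattach the detached value to initial with an identity-gradient trick. This is just a sketch of the idea (the class name FixedPoint3 is mine, and I have not tested it against your exact setup):

import torch
import torch.nn as nn

class FixedPoint3(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.initial = nn.Parameter(torch.rand(dim))
        self.step = nn.Linear(dim, dim)

    def forward(self, steps=5):
        x = self.initial
        # the first iterations are still untracked
        with torch.no_grad():
            for i in range(steps):
                x = self.step(x)

        # value equals x, but the gradient path into self.initial is the identity,
        # so d(loss)/d(initial) == d(loss)/d(second_to_last_x)
        second_to_last_x = x + self.initial - self.initial.detach()
        return self.step(second_to_last_x)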