Nested Modules and Backward Hooks

Hi all,
I am trying to implement a binary tree structure with edge weights. Each node has a weight, and during the forward pass the tree is extended dynamically. Since I want to calculate the gradients myself, I also wanted to use backward hooks. However, none of the backward hooks are called, even though they are registered. All parameters are registered, but I am not sure whether the nested modules in A are stored correctly.

Actually, I want to pass a value during the backward pass starting at the leaves, from the bottom up, i.e. a leaf calculates a value using its gradient, which is then passed to the parent node so that it can calculate its own gradient. Is there perhaps another way to do this? (I thought, or hoped, a hook could manage this…)
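
To make this more concrete, here is a rough sketch of the kind of bottom-up message passing I have in mind, on plain tensors instead of my tree classes and using tensor hooks (register_hook); the "mailbox" dict and all the names are only for illustration:

import torch

x = torch.randn(4)
w_internal = torch.randn(4, requires_grad=True)  # "parent" weight, used early in forward
w_leaf = torch.randn(1, requires_grad=True)      # "leaf" weight, used late in forward

message = {}  # mailbox: the leaf writes into it, the parent reads from it

def leaf_hook(grad):
    # the leaf derives some value from its own gradient ...
    message['from_leaf'] = grad.detach().sum()
    print("leaf hook fired, passing up:", message['from_leaf'])

def internal_hook(grad):
    # ... which the parent reads when its own gradient arrives later in backward
    print("internal hook fired, received:", message.get('from_leaf'))

w_leaf.register_hook(leaf_hook)
w_internal.register_hook(internal_hook)

routed = (x * w_internal).sum()  # internal node acts early in forward
yhat = routed * w_leaf           # leaf acts last, so its gradient arrives first in backward
yhat.backward()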

A minimal example of my code:

import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim):
        super(A, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.root = Leaf(self.output_dim)

    def forward(self, X):
        # Dynamically extend the tree
        self.root = self.root.split(self.input_dim, self.output_dim)
        # traverse the tree to compute yhat and end in the leaves
        return torch.matmul(X, self.root.internal_weight)


class Leaf(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.leaf_weight = nn.Parameter(torch.randn(output_dim), requires_grad=True)
        self.register_backward_hook(self.hook)

    def split(self, input_dim, output_dim):
        # create new leaf node and one parent node
        new_internal = Internal(input_dim)
        new_internal.left_leaf = self
        new_internal.right_leaf = Leaf(output_dim)
        return new_internal

    def hook(self, module, grad_input, grad_output):
        print("Hook")

class Internal(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.internal_weight = nn.Parameter(torch.randn(input_dim), requires_grad=True)
        self.left_leaf = None
        self.right_leaf = None
        self.hook_handle = self.register_backward_hook(self.hook)  # keep the handle without shadowing the method

    def hook(self, module, grad_input, grad_output):
        print("Hook")

Main script:

from torch.optim import Adam

lr = 0.05
a = A(5, 1)
criterion = nn.MSELoss()
optimizer = Adam(a.parameters(), lr=lr)
old_params = sum(1 for _ in a.parameters())
input = torch.randn(5)
target = torch.randn(1)

yhat = a(input)
# the number of parameters has changed, so 'update' the optimizer
if old_params != sum(1 for _ in a.parameters()):
    optimizer = Adam(a.parameters(), lr=lr)
    old_params = sum(1 for _ in a.parameters())
loss = criterion(yhat, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(dict(a.named_parameters()))
print(a.modules)

Output:

{'root.internal_weight': Parameter containing:
tensor([-0.0454, -1.3786, 0.1479, -0.3696, 1.6160], requires_grad=True), 'root.left_leaf.leaf_weight': Parameter containing:
tensor([2.3965], requires_grad=True), 'root.right_leaf.leaf_weight': Parameter containing:
tensor([0.6417], requires_grad=True)}
<bound method Module.modules of A(
  (root): Internal(
    (left_leaf): Leaf()
    (right_leaf): Leaf()
  )
)>

Thank you so much for reading and thinking about this!

In case anyone is interested in a solution:
For me it worked to use a torch.nn.ParameterDict to register all parameters correctly.

class A(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim):
        super(A, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.root = Leaf(self.output_dim)
        self.weights = torch.nn.ParameterDict({'leaf' : nn.Parameter(torch.randn(output_dim), requires_grad=True)})

Both the Leaf and the Internal node classes no longer inherit from nn.Module, and their parameters are moved to the ParameterDict in class A.
To extend the tree dynamically, I add and delete parameters in the ParameterDict.
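
Roughly like this, as a sketch (split_leaf and the key scheme are made up here just for illustration; after the parameters change, the optimizer still has to be rebuilt as in the main script above):

# inside class A, assuming each node is identified by a string key
def split_leaf(self, leaf_key, input_dim, output_dim):
    # the node that used to be a leaf gets an internal (routing) weight ...
    self.weights[leaf_key + '_internal'] = nn.Parameter(torch.randn(input_dim))
    # ... and a new sibling leaf weight is created
    self.weights[leaf_key + '_right'] = nn.Parameter(torch.randn(output_dim))

# pruning a node again is a plain dict deletion:
# del self.weights[leaf_key + '_right']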
I hope this helps someone!