Nested Modules and Backward Hooks

Hi all,
I am trying to implement a binary tree structure with edge weights. Each node has a weight, and the tree is extended dynamically during the forward pass. Since I want to calculate the gradients myself, I wanted to use backward hooks. However, none of the backward hooks are called, even though they are registered. All parameters are registered, but I am not sure whether the nested modules in A are stored correctly.
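
For reference, here is a minimal sketch of the mechanism in question. It assumes PyTorch 1.8 or newer for `register_full_backward_hook`, the non-deprecated replacement for `register_backward_hook`. Module backward hooks are set up inside `Module.__call__`, so they can only fire for a module that is actually invoked as `module(x)`. In the code below, `A.forward` reads `self.root.internal_weight` directly and never calls `Leaf` or `Internal`, which would explain why no hook fires. `Node` is a hypothetical stand-in for either class. A hook registered this way is called as `hook(module, grad_input, grad_output)`.

```python
import torch
import torch.nn as nn


class Node(nn.Module):  # hypothetical stand-in for Leaf / Internal
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim))
        self.calls = 0  # how many times the backward hook has fired
        self.register_full_backward_hook(self.hook)

    def forward(self, x):
        return torch.dot(x, self.weight)

    def hook(self, module, grad_input, grad_output):
        # Module backward hooks are called as hook(module, grad_input, grad_output).
        self.calls += 1


node = Node(5)
x = torch.randn(5, requires_grad=True)

# 1) The module is called, so Module.__call__ wires up its backward hook.
node(x).backward()
print(node.calls)  # 1

# 2) The parameter is used directly, bypassing Module.__call__:
#    no module hook is attached to this graph.
torch.dot(x, node.weight).backward()
print(node.calls)  # still 1
```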

Actually, what I want is to pass a value bottom-up during the backward pass, starting at the leaves: a leaf calculates a value from its gradient, and this value is then passed to the parent node to calculate the parent's gradient. Is there perhaps another way to do this? (I thought, or hoped, that a hook could manage this…)
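
One possible alternative to module hooks for this leaves-to-root pass is to skip autograd for this step and walk the tree in post-order yourself, writing each parameter's `.grad` by hand. The following is a sketch under stated assumptions. Optimizers such as `Adam` take the gradients from `p.grad` in `step()`, so they work unchanged afterwards. The `Leaf`/`Internal` classes are simplified versions of the ones below, with children passed to the constructor. `leaf_rule` and `internal_rule` are hypothetical placeholders for the actual gradient and for the value sent upward.

```python
import torch
import torch.nn as nn


class Leaf(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.leaf_weight = nn.Parameter(torch.randn(output_dim))


class Internal(nn.Module):
    def __init__(self, input_dim, left_leaf, right_leaf):
        super().__init__()
        self.internal_weight = nn.Parameter(torch.randn(input_dim))
        self.left_leaf = left_leaf    # assigned as attributes, so registered as submodules
        self.right_leaf = right_leaf


def leaf_rule(leaf):
    # HYPOTHETICAL placeholder: replace with the real leaf gradient.
    grad = torch.ones_like(leaf.leaf_weight)
    return grad, grad.sum()  # (gradient, value passed to the parent)


def internal_rule(node, child_values):
    # HYPOTHETICAL placeholder: the parent's gradient uses its children's values.
    grad = sum(child_values) * torch.ones_like(node.internal_weight)
    return grad, grad.sum()


@torch.no_grad()
def manual_backward(node):
    """Post-order walk: both children are processed before their parent."""
    if isinstance(node, Leaf):
        grad, value = leaf_rule(node)
        node.leaf_weight.grad = grad
    else:
        child_values = [manual_backward(node.left_leaf),
                        manual_backward(node.right_leaf)]
        grad, value = internal_rule(node, child_values)
        node.internal_weight.grad = grad
    return value


root = Internal(5, Internal(5, Leaf(1), Leaf(1)), Leaf(1))
optimizer = torch.optim.Adam(root.parameters(), lr=0.05)
optimizer.zero_grad()
top_value = manual_backward(root)  # fills .grad on every parameter, leaves first
optimizer.step()                   # uses the gradients stored in p.grad
```

Because each call returns its value only after both children have returned, every parent sees its children's values, which matches the bottom-up order described above.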

A minimal example of my code:

import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.root = Leaf(self.output_dim)

    def forward(self, X):
        # Dynamically extend the tree
        self.root = self.root.split(self.input_dim, self.output_dim)
        # traverse the tree to compute yhat and end in the leaves
        return torch.matmul(X, self.root.internal_weight)


class Leaf(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.leaf_weight = nn.Parameter(torch.randn(output_dim), requires_grad=True)
        self.register_backward_hook(self.hook)

    def split(self, input_dim, output_dim):
        # create new leaf node and one parent node
        new_internal = Internal(input_dim)
        new_internal.left_leaf = self
        new_internal.right_leaf = Leaf(output_dim)
        return new_internal

    def hook(self, module, grad_input, grad_output):
        # module backward hooks are called as hook(module, grad_input, grad_output)
        print("Hook")

class Internal(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.internal_weight = nn.Parameter(torch.randn(input_dim), requires_grad=True)
        self.left_leaf = None
        self.right_leaf = None
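        # store the handle under a separate name so it does not shadow the hook method
        self.hook_handle = self.register_backward_hook(self.hook)

    def hook(self, module, grad_input, grad_output):
        # module backward hooks are called as hook(module, grad_input, grad_output)
        print("Hook")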

Main script:

from torch.optim import Adam

lr = 0.05
a = A(5, 1)
criterion = nn.MSELoss()
optimizer = Adam(a.parameters(), lr=lr)
old_params = sum(1 for _ in a.parameters())
input = torch.randn(5)
target = torch.randn(1)

yhat = a(input)
# the number of parameters has changed, so 'update' the optimizer
if old_params != sum(1 for _ in a.parameters()):
    optimizer = Adam(a.parameters(), lr=lr)
    old_params = sum(1 for _ in a.parameters())
loss = criterion(yhat, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(dict(a.named_parameters()))
print(a.modules)

Output:

{'root.internal_weight': Parameter containing:
tensor([-0.0454, -1.3786,  0.1479, -0.3696,  1.6160], requires_grad=True), 'root.left_leaf.leaf_weight': Parameter containing:
tensor([2.3965], requires_grad=True), 'root.right_leaf.leaf_weight': Parameter containing:
tensor([0.6417], requires_grad=True)}
<bound method Module.modules of A(
  (root): Internal(
    (left_leaf): Leaf()
    (right_leaf): Leaf()
  )
)>
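
The output confirms that the nested modules and their parameters are registered. Registration alone does not put a parameter into the autograd graph, though: `A.forward` only uses `self.root.internal_weight`, so the leaf weights never receive a gradient. A small self-contained sketch of this uses plain parameters in place of the modules. It relies on `Tensor.register_hook`, which is called with a tensor's gradient every time that gradient is computed, whether or not any module is called. The parameter names only mirror the ones printed above.

```python
import torch
import torch.nn as nn

# Plain parameters standing in for root.internal_weight and a leaf's leaf_weight.
internal_weight = nn.Parameter(torch.randn(5))
leaf_weight = nn.Parameter(torch.randn(1))

fired = []
# Each tensor hook receives that tensor's gradient; returning None leaves it unchanged.
internal_weight.register_hook(lambda grad: fired.append("internal_weight"))
leaf_weight.register_hook(lambda grad: fired.append("leaf_weight"))

x = torch.randn(5)
yhat = torch.matmul(x, internal_weight)  # same computation as A.forward
loss = (yhat - 1.0) ** 2
loss.backward()

print(fired)             # ['internal_weight']
print(leaf_weight.grad)  # None: leaf_weight never entered the graph
```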

Thank you so much for reading and thinking about this!