Nesting an autograd graph inside another


There is an increasing amount of research (meta-learning, deep equilibrium models, OptNets, etc.) where the forward pass of a model includes solving a separate optimization problem.

In many cases, we don’t need to track this inner optimization problem as part of the same graph as our main model (e.g. because we derive gradients manually and add them as a hook).
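For the hook route, one possible pattern is to solve the inner problem without tracking it, then re-attach the result to the outer graph through an identity path whose gradient a hook replaces with the hand-derived one. The sketch below assumes a linear inner problem A z = b, so that dL/db = A^{-T} dL/dz can be written by hand; `solve_with_manual_grad` is a hypothetical helper name, and only `torch.linalg.solve` and `Tensor.register_hook` are real PyTorch calls.

```python
import torch

def solve_with_manual_grad(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: z = A^{-1} b with a hand-written gradient w.r.t. b (A treated as constant)."""
    with torch.no_grad():
        z_star = torch.linalg.solve(A, b)  # inner solve, nothing recorded
    # b - b.detach() is zero in value but has an identity gradient w.r.t. b,
    # so z carries z_star's value while staying connected to b's graph.
    z = z_star + (b - b.detach())
    # dL/db = A^{-T} dL/dz: the hook swaps in that gradient before it reaches b.
    z.register_hook(lambda g: torch.linalg.solve(A.T, g))
    return z

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64, requires_grad=True)

z = solve_with_manual_grad(A, b)
z.pow(2).sum().backward()
manual_grad = b.grad.clone()

# Reference: let autograd differentiate through torch.linalg.solve directly.
b.grad = None
torch.linalg.solve(A, b).pow(2).sum().backward()
print(torch.allclose(manual_grad, b.grad))  # True
```

The design choice here is that the inner solver never builds a graph; only the single addition that re-attaches `z_star` is recorded, and the hook supplies the gradient the solver would otherwise have produced.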

The question is this: how can I run an entire inner graph that neither affects nor knows about the outer graph it is part of? Here is a sketch:

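```python
import torch
import torch.nn as nn

class OuterModel(nn.Module):
    def __init__(self, dim=8, inner_steps=20):
        super().__init__()
        self.tail = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())  # placeholder layers
        self.head = nn.Sequential(nn.Linear(dim, dim))             # placeholder layers
        self.inner_steps = inner_steps
        # A is not defined in the post; a fixed random matrix stands in for it
        self.register_buffer("A", torch.randn(dim, dim))

    def forward(self, x):
        tail_out = self.tail(x)  # tracked in outer graph

        # Grad mode must be ON here: under torch.no_grad() inner_loss has no
        # grad_fn and inner_loss.backward() raises. The inner graph stays
        # separate because it only sees tail_out.detach().
        with torch.enable_grad():
            inner_params = torch.zeros(tail_out.shape, requires_grad=True)  # assumed shape
            inner_optim = torch.optim.AdamW([inner_params], lr=1e-2)

            for _ in range(self.inner_steps):
                inner_optim.zero_grad()
                # row-wise A @ inner_params for each sample in the batch
                inner_loss = torch.norm(inner_params @ self.A.T - tail_out.detach())
                inner_loss.backward()  # only affects inner_params
                inner_optim.step()

        inner_params = inner_params.detach()  # outer graph doesn't need to track these

        out = tail_out + inner_params * self.head(tail_out)
        return out

## forward/backward pass of outer model
outer_model = OuterModel()
outer_optim = torch.optim.AdamW(outer_model.parameters())
loader = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(10)]  # stand-in data

for x, y in loader:
    outer_optim.zero_grad()
    out = outer_model(x)
    outer_loss = (out - y).pow(2).mean()  # backward() needs a scalar loss
    outer_loss.backward()  # only does backward for outer params
    outer_optim.step()
```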
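The isolation the comments in the pseudocode ask for ("this should only affect inner params") can be checked directly. When the inner loss only sees a detached copy of the outer activation and the inner variables are a fresh leaf with their own optimizer, `inner_loss.backward()` populates only that leaf's `.grad`. In the sketch below the layer sizes, `A`, the step count and the helper name `inner_solve` are placeholders; wrapping the loop in `torch.enable_grad()` is one way to keep it working when the outer forward itself runs under `torch.no_grad()` (e.g. during evaluation).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tail = nn.Linear(3, 3)  # placeholder for the outer model's tail
A = torch.randn(3, 3)   # placeholder inner-problem matrix
x = torch.randn(5, 3)

def inner_solve(target: torch.Tensor, steps: int = 50) -> torch.Tensor:
    # enable_grad re-enables graph recording even if the caller is under no_grad
    with torch.enable_grad():
        inner_params = torch.zeros(target.shape, requires_grad=True)
        inner_optim = torch.optim.AdamW([inner_params], lr=0.1)
        for _ in range(steps):
            inner_optim.zero_grad()
            inner_loss = torch.norm(inner_params @ A.T - target)
            inner_loss.backward()  # graph only links inner_params to inner_loss
            inner_optim.step()
    return inner_params.detach()

tail_out = tail(x)  # tracked in the outer graph
solution = inner_solve(tail_out.detach())
print(tail.weight.grad)  # None: the inner backward passes never reached tail

with torch.no_grad():  # e.g. evaluation: the inner loop still runs
    eval_solution = inner_solve(tail(x))
```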

The easiest way is to do this manually: `(inner_grad,) = torch.autograd.grad(inner_loss, inner_params)` and then `inner_params -= 0.1 * inner_grad` (inside `torch.no_grad()`). Ideally, though, we want to leverage all the nice optimizers PyTorch offers out of the box, and we don't want to concatenate all our inner_params into one tensor. I've also tried `inner_loss.backward(retain_graph=True)` to retain the outer graph, but then the model doesn't train (I suspect because the inner graph is also retained, or something similar).
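For the manual route, `torch.autograd.grad` accepts a list of inputs and returns one gradient per input, so separate inner tensors can be updated without concatenating them; the in-place update has to happen under `torch.no_grad()` because they are leaves that require grad. For the optimizer route, `Tensor.backward(inputs=...)` (available in recent PyTorch versions, 1.8 and later as far as I know) restricts gradient accumulation to the listed tensors, so an inner backward pass should not write into outer `.grad` fields even when the inner loss is not detached from the outer graph. The sketch below assumes a toy quadratic inner loss; `outer_w`, `feature` and `inner_loss_fn` are hypothetical names.

```python
import torch

torch.manual_seed(0)
outer_w = torch.randn(3, requires_grad=True)  # stands in for an outer parameter
feature = outer_w * 2.0                       # NOT detached: inner loss depends on outer_w

# Two separate inner tensors, kept as a list instead of one concatenated tensor.
inner_params = [torch.zeros(3, requires_grad=True), torch.zeros(3, requires_grad=True)]

def inner_loss_fn():
    return ((inner_params[0] + inner_params[1] - feature) ** 2).sum()

# (a) Manual SGD: one gradient per tensor, updated in place without recording.
for _ in range(5):
    grads = torch.autograd.grad(inner_loss_fn(), inner_params)
    with torch.no_grad():
        for p, g in zip(inner_params, grads):
            p -= 0.1 * g

# (b) A torch.optim optimizer; inputs= keeps outer_w.grad untouched.
inner_optim = torch.optim.SGD(inner_params, lr=0.1)
for _ in range(5):
    inner_optim.zero_grad()
    inner_loss_fn().backward(inputs=inner_params)
    inner_optim.step()

print(outer_w.grad)  # None: neither loop accumulated into the outer parameter
```

Without `inputs=inner_params`, the call in (b) would also accumulate a gradient into `outer_w.grad`, which is one way inner and outer gradients can get mixed.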

Thanks for any pointers!