Heya,

There is an increasing amount of research (meta-learning, deep equilibrium models, OptNets, etc.) where the forward pass of a model includes solving a separate optimization problem.

In many cases, we don’t need to track this *inner* optimization problem as part of the same graph as our main model (e.g. because we derive gradients manually and add them as a hook).
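
By "add them as a hook" I mean something along these lines (just a sketch, with `some_tensor` and `manual_grad` as placeholders for whatever the particular method gives you):

```
# replace autograd's gradient for this tensor with a manually derived one
# during the outer backward pass (placeholder names, just to illustrate)
some_tensor.register_hook(lambda grad: manual_grad)
```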

The question is this: how can I build and use an entire *inner* sub-graph that doesn't affect, or know about, the *outer* graph it sits inside? Here is some pseudocode:

```
import torch
import torch.nn as nn

class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tail = nn.Sequential(...)
        self.head = nn.Sequential(...)

    def forward(self, x):
        tail_out = self.tail(x)  # tracked in the outer graph
        with torch.no_grad():  # everything happening here should be tracked in some inner graph only
            inner_params = torch.zeros(A.shape[1], requires_grad=True)
            inner_optim = torch.optim.AdamW([inner_params])
            for i in range(inner_steps):
                inner_loss = torch.norm(A @ inner_params - tail_out.detach())
                inner_loss.backward()  # this should only affect inner_params
                inner_optim.step()
                inner_optim.zero_grad()
        inner_params.detach_()  # the outer graph doesn't need to track these
        out = tail_out + inner_params * self.head(tail_out)
        return out

## forward/backward pass of the outer model
outer_model = OuterModel()
outer_optim = torch.optim.AdamW(outer_model.parameters())
for x, y in loader:
    out = outer_model(x)
    outer_loss = torch.norm(out - y)
    outer_loss.backward()  # only does backward for the outer params
    outer_optim.step()
    outer_optim.zero_grad()
```

The easiest way is to do this manually: `(inner_grad,) = torch.autograd.grad(inner_loss, inner_params)`

and then `inner_params -= 0.1 * inner_grad` (wrapped in `torch.no_grad()`, since `inner_params` is a leaf that requires grad).
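
Spelled out, that looks roughly like this (a sketch only, reusing `A`, `inner_steps` and `tail_out` from the pseudocode above, with a plain gradient-descent step instead of a proper optimizer):

```
# manual inner loop: each step builds a small graph for inner_loss only,
# takes the gradient w.r.t. inner_params, and throws that graph away
inner_params = torch.zeros(A.shape[1], requires_grad=True)
for _ in range(inner_steps):
    inner_loss = torch.norm(A @ inner_params - tail_out.detach())
    (inner_grad,) = torch.autograd.grad(inner_loss, inner_params)
    with torch.no_grad():  # the in-place update on a leaf must not be recorded
        inner_params -= 0.1 * inner_grad
inner_params = inner_params.detach()  # a plain constant as far as the outer graph is concerned
```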

But ideally we want to leverage all the nice optimizers PyTorch offers out of the box, and we don't want to concatenate all our `inner_params` into one tensor. I've also tried playing with `inner_loss.backward(retain_graph=True)`

to retain the outer graph, but then the outer model doesn't train (I suspect because the inner graph is also retained, or something along those lines).
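
For reference, that attempt replaces the `torch.no_grad()` block in the pseudocode above with roughly the following (same placeholder names as above):

```
# same inner loop, but without the no_grad wrapper; the inner backward is
# asked to keep its graph around via retain_graph=True
inner_params = torch.zeros(A.shape[1], requires_grad=True)
inner_optim = torch.optim.AdamW([inner_params])
for i in range(inner_steps):
    inner_loss = torch.norm(A @ inner_params - tail_out.detach())
    inner_loss.backward(retain_graph=True)
    inner_optim.step()
    inner_optim.zero_grad()
```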

Thanks for any pointers!

Paul