Heya,
There is an increasing amount of research (meta-learning, deep equilibrium models, OptNets, etc.) where the forward pass of a model includes solving a separate optimization problem.
In many cases, we don’t need to track this inner optimization problem as part of the same graph as our main model (e.g. because we derive gradients manually and add them as a hook).
The question is this: how can I run an entire inner sub-graph that neither affects nor knows about the outer graph it sits inside? Here is some pseudo-code:
```python
class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tail = nn.Sequential(...)
        self.head = nn.Sequential(...)

    def forward(self, x):
        tail_out = self.tail(x)  # tracked in the outer graph
        with torch.no_grad():  # everything happening here should be tracked in some inner graph only
            inner_params = torch.zeros(dim, requires_grad=True)
            inner_optim = torch.optim.AdamW([inner_params])
            for i in range(inner_steps):
                inner_loss = torch.norm(A @ inner_params - tail_out.detach())
                inner_loss.backward()  # this should only affect inner_params
                inner_optim.step()
                inner_optim.zero_grad()
        inner_params.detach_()  # outer graph doesn't need to track these
        out = tail_out + inner_params * self.head(tail_out)
        return out
```
```python
# forward/backward pass of the outer model
outer_model = OuterModel()
outer_optim = torch.optim.AdamW(outer_model.parameters())
for x, y in loader:
    out = outer_model(x)
    outer_loss = torch.norm(out - y)
    outer_loss.backward()  # only does backward for the outer params
    outer_optim.step()
    outer_optim.zero_grad()
```
The easiest way is to do this manually: `inner_grad, = torch.autograd.grad(inner_loss, inner_params)` followed by `inner_params -= 0.1 * inner_grad`.
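As a stripped-down illustration of that manual route (`d`, `A`, the target, and the `0.1` step size are all placeholders here, standing in for the real inner problem):

```python
import torch

torch.manual_seed(0)
d = 5
A = torch.randn(d, d)
target = torch.randn(d)  # stands in for tail_out.detach()

inner_params = torch.zeros(d, requires_grad=True)
for _ in range(20):
    inner_loss = torch.norm(A @ inner_params - target)
    # gradient w.r.t. inner_params only; no other tensor is touched,
    # and the temporary graph is freed immediately
    inner_grad, = torch.autograd.grad(inner_loss, inner_params)
    with torch.no_grad():
        inner_params -= 0.1 * inner_grad
```

Note that `torch.autograd.grad` returns the gradient directly instead of accumulating into `.grad`, which is what keeps this update self-contained.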
Ideally, though, we want to leverage all the nice optimizers PyTorch offers out of the box, and we don't want to concatenate all our `inner_params` into one tensor. I've also tried playing with `inner_loss.backward(retain_graph=True)` to retain the outer graph, but then the model doesn't train (I suspect because the inner graph is also retained, or something along those lines).
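To make the behaviour I'm after concrete, here is a self-contained sketch outside any module. Since the inner loss is built only from a detached copy of the outer activation plus a fresh leaf tensor, its graph should already be disjoint from the outer one (`d`, `A`, `inner_steps`, and the learning rate are made up for illustration):

```python
import torch

torch.manual_seed(0)
d, inner_steps = 5, 20
A = torch.randn(d, d)

# pretend this came out of self.tail(x) and is part of the outer graph
tail_out = torch.randn(d, requires_grad=True)
tail_out_d = tail_out.detach()  # the inner problem only ever sees the detached copy

# the inner graph's only inputs are tail_out_d (no grad) and inner_params,
# so inner_loss.backward() can never reach the outer graph
inner_params = torch.zeros(d, requires_grad=True)
inner_optim = torch.optim.AdamW([inner_params], lr=0.1)
for _ in range(inner_steps):
    inner_loss = torch.norm(A @ inner_params - tail_out_d)
    inner_loss.backward()  # populates inner_params.grad only
    inner_optim.step()
    inner_optim.zero_grad()

# use the solved values as constants in the outer computation
out = tail_out + inner_params.detach() * tail_out
out.sum().backward()  # gradients flow to tail_out only
```

The question is whether this pattern is safe to nest inside an outer `forward`, or whether the interleaved `backward` calls can corrupt the outer graph.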
Thanks for any pointers!
Paul