There is a growing body of research (meta-learning, deep equilibrium models, OptNet, etc.) where the forward pass of a model includes solving a separate optimization problem.
In many cases, we don’t need to track this inner optimization problem as part of the same graph as our main model (e.g. because we derive gradients manually and add them as a hook).
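For context, "deriving gradients manually and adding them as a hook" can look like the following minimal sketch (the tensor shapes and the constant `manual_grad` are made up for illustration):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# Hypothetical manually derived gradient for x (a constant here, just for illustration)
manual_grad = torch.ones(3)

# A tensor hook may return a new gradient, which replaces whatever autograd
# computed for x before it is accumulated into x.grad
x.register_hook(lambda g: manual_grad)

y.backward()
# x.grad now holds manual_grad, not the autograd result (which would be all 2s)
```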
The question is this: how can I run an entire inner sub-graph that doesn't affect, or know about, the outer graph it is part of? Here is some pseudo-code:
```python
import torch
import torch.nn as nn

class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tail = nn.Sequential(...)
        self.head = nn.Sequential(...)

    def forward(self, x):
        tail_out = self.tail(x)  # tracked in the outer graph
        with torch.no_grad():  # everything happening here should be tracked in some inner graph only
            inner_params = torch.zeros(d, requires_grad=True)
            inner_optim = torch.optim.AdamW([inner_params])
            for i in range(inner_steps):
                inner_loss = torch.norm(A @ inner_params - tail_out.detach())
                inner_loss.backward()  # this should only affect inner_params
                inner_optim.step()
                inner_optim.zero_grad()
            inner_params.detach_()  # the outer graph doesn't need to track these
        out = tail_out + inner_params * self.head(tail_out)
        return out

## forward/backward pass of the outer model
outer_model = OuterModel()
outer_optim = torch.optim.AdamW(outer_model.parameters())
for x, y in loader:
    out = outer_model(x)
    outer_loss = ((out - y) ** 2).mean()
    outer_loss.backward()  # only does backward for the outer params
    outer_optim.step()
    outer_optim.zero_grad()
```
The easiest way is to do this manually:
`inner_grad = torch.autograd.grad(inner_loss, inner_params)[0]` and then
`inner_params -= 0.1 * inner_grad`, but ideally we want to leverage all the nice optimizers PyTorch offers out of the box, and we don't want to concatenate all our `inner_params` into one tensor. I've also tried playing with
`inner_loss.backward(retain_graph=True)` to retain the outer graph, but this doesn't train (I suspect because the inner graph is also retained, or something along those lines).
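For concreteness, the manual `torch.autograd.grad` route above can be sketched as a self-contained snippet (the shapes, `A`, the iteration count, and the `0.1` step size are all illustrative assumptions):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)
tail_out = torch.randn(4)  # stand-in for the (detached) tail output

inner_params = torch.zeros(3, requires_grad=True)
for _ in range(10):
    inner_loss = torch.norm(A @ inner_params - tail_out)
    # grad() returns a tuple; it neither populates .grad fields
    # nor connects inner_params to any outer graph
    (inner_grad,) = torch.autograd.grad(inner_loss, inner_params)
    with torch.no_grad():
        inner_params -= 0.1 * inner_grad
```

This works, but as noted above it amounts to re-implementing plain SGD by hand instead of reusing PyTorch's optimizers.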
Thanks for any pointers!