Given a model with weights w(t) at training step t and the updated weights w(t+1) after one step, what is the most efficient way to restore the model from w(t+1) back to w(t)?
After reverting to w(t), another training step follows, producing a new and different w(t+1). Second-order derivatives must remain possible through the first w(t+1) and w(t).
I am trying to implement the following code, but in PyTorch.
EDIT:
Currently my solution is:
optim = torch.optim.Adam(self.model.parameters(), lr=ARGS.lr)
losses = []
for inner_x, outer_x, inner_y, outer_y in SomeDataLoader:
    gradss = []
    # inner loop: take `iterations` SGD steps on the inner loss,
    # keeping the graph so second-order terms survive
    for j in range(iterations):
        self.model.zero_grad()
        inner_output = self.model(inner_x)
        inner_loss = self.criterion(inner_output, inner_y)
        grads = torch.autograd.grad(inner_loss, self.model.parameters(), create_graph=True)
        gradss.append(grads)
        with torch.no_grad():
            for i, param in enumerate(self.model.parameters()):
                param -= self.inner_lr * grads[i]
    outer_output = self.model(outer_x)
    outer_loss = self.criterion(outer_output, outer_y)
    # undo the inner updates to return the model to w(t)
    with torch.no_grad():
        for j in range(iterations):
            for i, param in enumerate(self.model.parameters()):
                param.data.copy_(param.data.clone().detach() + self.inner_lr * gradss[j][i].detach())
    losses.append(outer_loss)
loss = torch.stack(losses).mean()
loss.backward()
optim.step()
However, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
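The error arises because the parameters are updated in place while the autograd graph built with `create_graph=True` still references their old values. A common way around this (the functional "fast weights" pattern used in MAML-style implementations) is to leave `model.parameters()` untouched and thread the updated weights through `torch.func.functional_call` instead, so w(t) never has to be restored at all. A minimal sketch, assuming PyTorch ≥ 2.0; the linear model, MSE loss, `inner_lr`, `iterations`, and random data are illustrative stand-ins for the question's `self.model`, `self.criterion`, etc.:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
inner_lr, iterations = 0.1, 3

inner_x, inner_y = torch.randn(8, 4), torch.randn(8, 1)
outer_x, outer_y = torch.randn(8, 4), torch.randn(8, 1)

# Start the inner loop from the model's current weights w(t).
# The originals are never modified, so nothing needs restoring.
fast_weights = dict(model.named_parameters())

for _ in range(iterations):
    inner_out = functional_call(model, fast_weights, (inner_x,))
    inner_loss = criterion(inner_out, inner_y)
    # create_graph=True keeps the graph so second-order terms can
    # flow through w(t+1) back to w(t)
    grads = torch.autograd.grad(inner_loss,
                                tuple(fast_weights.values()),
                                create_graph=True)
    # out-of-place update: builds new tensors instead of mutating params
    fast_weights = {name: w - inner_lr * g
                    for (name, w), g in zip(fast_weights.items(), grads)}

# The outer loss is evaluated at the adapted weights w(t+1) ...
outer_loss = criterion(functional_call(model, fast_weights, (outer_x,)), outer_y)
# ... but backward() deposits gradients on the untouched leaves w(t),
# ready for optim.step() on the original parameters.
outer_loss.backward()
```

Because every inner update is out of place, the in-place `RuntimeError` cannot occur, and the "revert" loop at the end of the original code becomes unnecessary.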