#meta learning# #gradient computation error# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I'm training a DL model under the MAML framework, but I get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 1]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

A simplified version of my MAML training step is:

from copy import deepcopy

weights_original = deepcopy(model.state_dict())

# Inner loop: adapt the parameters on the support set
output_support = model(input_support)
optimizer.zero_grad()
loss = criterion(output_support, gt_support)
loss.backward()
optimizer.step()

# Outer loop: evaluate the adapted model on the query set
optimizer.zero_grad()
output_query = model(input_query)
loss = criterion(output_query, gt_query)
# Restore the initial params to calculate gradients w.r.t. them and update them
model.load_state_dict({name: weights_original[name] for name in weights_original})
loss.backward()
optimizer.step()

For easier debugging, I use a simple model:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        # Plain linear layers stand in for the real transformer encoder/decoder
        self.transformer_encoder = torch.nn.Linear(512, 512)
        self.transformer_decoder = torch.nn.Linear(512, 512)

    def forward(self, input):
        output = self.transformer_encoder(input)
        output = self.transformer_decoder(output)
        return output

Digging into the bug, I found that if the model consists of a single module (e.g. only the encoder or only the decoder), the error disappears. But if the model includes both the encoder and the decoder, the error reappears. Does anyone know a solution? Thank you!
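A minimal sketch that reproduces this observation, assuming a recent PyTorch version. It mimics the order used above (forward, restore the weights in place with `load_state_dict`, then `backward()`); the helper `backward_after_restore` and the random input are hypothetical and only for illustration.

```python
import copy

import torch
import torch.nn as nn


def backward_after_restore(model: nn.Module) -> bool:
    """Forward, restore the weights in place, then backward (same order as above).

    Returns True if backward succeeds, False if autograd raises the in-place error.
    """
    weights_original = copy.deepcopy(model.state_dict())
    x = torch.randn(4, 512)  # plain input, does not require grad
    loss = model(x).mean()
    # load_state_dict copies into each parameter in place, bumping its version counter
    model.load_state_dict(weights_original)
    try:
        loss.backward()
        return True
    except RuntimeError:
        return False


single = nn.Linear(512, 512)
double = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

print("single module:", backward_after_restore(single))
print("encoder + decoder:", backward_after_restore(double))
```

A plausible reason for the difference: for a 2D input, `nn.Linear` is computed with `addmm`, and autograd only saves the weight when the gradient w.r.t. the layer's *input* is needed. The first layer's input does not require grad, so its weight is never saved and restoring it in place is harmless. The decoder's input (the encoder output) does require grad, so the decoder weight is saved for backward, and the in-place restore invalidates it.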

Loading the state_dict modifies the parameters in place, which causes the failure here:

model.load_state_dict({name: weights_original[name] for name in weights_original})
loss.backward()

Changing the parameters before calling backward() makes the tensors saved during the forward pass stale; without this check, autograd would compute wrong gradients.
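One possible reordering, sketched under the assumption that a first-order approximation of MAML is acceptable: call `backward()` on the query loss while the adapted parameters are still in place, and only then restore the original parameters. `load_state_dict` overwrites the parameter values but leaves their `.grad` untouched, so `optimizer.step()` applies the query gradients (evaluated at the adapted parameters) to the original ones. The small `nn.Sequential` model, the SGD optimizer, and the random batches are stand-ins, not the code from the question.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # stand-in for encoder + decoder
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Hypothetical random support/query batches
x_s, y_s = torch.randn(4, 8), torch.randn(4, 8)
x_q, y_q = torch.randn(4, 8), torch.randn(4, 8)

weights_original = copy.deepcopy(model.state_dict())

# Inner step: adapt on the support set
optimizer.zero_grad()
criterion(model(x_s), y_s).backward()
optimizer.step()

# Outer step: compute the query gradients at the adapted parameters first ...
optimizer.zero_grad()
query_loss = criterion(model(x_q), y_q)
query_loss.backward()

# ... then restore the original parameters (in place, .grad is kept)
# and apply the query gradients to them
model.load_state_dict(weights_original)
optimizer.step()
```

This drops the second-order terms of MAML. Also note that with a stateful optimizer such as Adam, sharing one optimizer between the inner and outer steps mixes their moment estimates.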

This post explains the issue in more detail for GAN training, which has the same root cause: stale forward activations.

Thanks for your reply. However, I need the model.load_state_dict call because MAML computes the gradients w.r.t. the initial parameters.
Yes, when I remove model.load_state_dict, it works fine. But if the model has only one module, such as the encoder defined above, it also works fine with model.load_state_dict. There is also a public implementation (line 156) that performs a similar operation; when I unwrap the code at line 150, it has basically the same pipeline as the MAML framework above.
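If the full (second-order) MAML gradient w.r.t. the initial parameters is needed, one approach that avoids in-place restores altogether is to keep the original parameters untouched and run the model with *functional* adapted parameters. The sketch below assumes PyTorch 2.0 or later for `torch.func.functional_call`; the model, `inner_lr`, the Adam meta-optimizer, and the random batches are hypothetical stand-ins, not the linked implementation.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # stand-in for encoder + decoder
criterion = nn.MSELoss()
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
inner_lr = 0.01  # hypothetical inner-loop step size

# Hypothetical random support/query batches
x_s, y_s = torch.randn(4, 8), torch.randn(4, 8)
x_q, y_q = torch.randn(4, 8), torch.randn(4, 8)

params = dict(model.named_parameters())

# Inner step: support gradients with create_graph=True so we can differentiate through them
support_loss = criterion(functional_call(model, params, (x_s,)), y_s)
grads = torch.autograd.grad(support_loss, list(params.values()), create_graph=True)

# Adapted parameters are new tensors; the original parameters are never modified in place
adapted = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# Outer step: the query loss depends on the original parameters through `adapted`
query_loss = criterion(functional_call(model, adapted, (x_q,)), y_q)
meta_optimizer.zero_grad()
query_loss.backward()  # gradients flow back to model.parameters(), incl. second-order terms
meta_optimizer.step()
```

Because nothing is written into the parameters between the forward pass and `backward()`, the version-counter check has nothing to complain about, regardless of how many modules the model has.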