I am raising this issue with the following example:
```python
import torch
from collections import OrderedDict

for inner_step in range(self.inner_steps):
    # Functional forward pass through the model with the current fast weights.
    d_loss, _, bert_outputs = model.forward_with_params(
        inputs_embeds=self.inputs_embeds,
        labels=self.labels,
        weights=weights,
        output_attentions=True,
    )
    # Differentiate the loss w.r.t. the fast weights; create_graph=True keeps
    # the update differentiable for the outer loop.
    grads = torch.autograd.grad(
        d_loss, weights.values(), create_graph=True, allow_unused=True
    )
    # Gradient-descent step on the fast weights (skip params with no grad).
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )
```
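For reference, here is a minimal, self-contained sketch that reproduces the same pattern with a toy linear model (`forward_with_params` is my own wrapper around BERT, so the model below is just a stand-in, not my actual code); the allocated memory readout grows with every inner step:

```python
import torch
from collections import OrderedDict

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

inputs = torch.randn(8, 16, device=device)
targets = torch.randn(8, 1, device=device)
weights = OrderedDict(
    w=torch.randn(16, 1, device=device, requires_grad=True),
    b=torch.zeros(1, device=device, requires_grad=True),
)

for inner_step in range(10):
    # Toy stand-in for the functional forward pass.
    preds = inputs @ weights["w"] + weights["b"]
    loss = torch.nn.functional.mse_loss(preds, targets)
    grads = torch.autograd.grad(
        loss, weights.values(), create_graph=True, allow_unused=True
    )
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )
    # With create_graph=True, each new set of weights keeps the graphs of
    # all previous steps alive, so allocated memory grows per step.
    if device == "cuda":
        print(inner_step, torch.cuda.memory_allocated(device))
```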
Now, when I watch the GPU memory consumption during these inner steps, I notice that it grows substantially with every step. I also tried performing the weight updates under no_grad(), but that did not help either.
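Concretely, the no_grad() variant I tried looked roughly like this (same names as in the first snippet); the memory growth was unchanged:

```python
# Roughly the variant I tried: only the parameter update runs under
# torch.no_grad(); the forward pass and the autograd.grad call above
# are unchanged.
with torch.no_grad():
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )
```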
Your help is appreciated!