GPU memory consumption keeps increasing when computing gradients in an inner loop

Hi all,

I am running into an issue with the following example:

import torch
from collections import OrderedDict

for inner_step in range(self.inner_steps):
    # Forward pass with the current (functional) weights
    d_loss, _, bert_outputs = model.forward_with_params(
        inputs_embeds=self.inputs_embeds,
        labels=self.labels,
        weights=weights,
        output_attentions=True,
    )
    # Differentiable gradients w.r.t. the functional weights
    grads = torch.autograd.grad(
        d_loss, weights.values(), create_graph=True, allow_unused=True
    )
    # Inner-loop update: the graph is kept so the update itself stays differentiable
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )

When I watch GPU memory consumption during this loop, I notice a huge increase at every inner step. I tried wrapping the weight update in no_grad(), but that did not help either.
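For reference, this is roughly the no_grad() variant I tried (a minimal sketch; it reuses the weights and grads from the loop above):

    # Sketch of the attempted update with gradient tracking disabled,
    # reusing `weights` and `grads` from the inner loop shown above.
    with torch.no_grad():
        weights = OrderedDict(
            (name, param - grad) if grad is not None else (name, param)
            for (name, param), grad in zip(weights.items(), grads)
        )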
Your help is appreciated!

Double post from here.