Hi all,
I am raising an issue that I can reproduce with the following example:
# (assumes: import torch; from collections import OrderedDict)
for inner_step in range(self.inner_steps):
    # Forward pass with the current functional weights
    d_loss, _, bert_outputs = model.forward_with_params(
        inputs_embeds=self.inputs_embeds,
        labels=self.labels,
        weights=weights,
        output_attentions=True,
    )
    # Gradients w.r.t. the current weights; create_graph=True keeps
    # the graph alive so the update itself stays differentiable
    grads = torch.autograd.grad(
        d_loss, weights.values(), create_graph=True, allow_unused=True
    )
    # SGD-style functional update of every weight (skip unused ones)
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )
Now, when I watch GPU memory during this loop, I see it grow substantially with every inner step. I also tried performing the weight updates under torch.no_grad(), but that did not help either.
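For reference, here is a minimal, self-contained sketch of the no_grad() variant of the update that I tried; the toy weight dict and loss below are placeholders standing in for the real model, not my actual code:

import torch
from collections import OrderedDict

# Toy stand-ins for the real functional weights and inner loss
# (illustrative only; in my code these come from the BERT model).
weights = OrderedDict([("dense.weight", torch.randn(4, 4, requires_grad=True))])
d_loss = sum(p.pow(2).sum() for p in weights.values())
grads = torch.autograd.grad(d_loss, weights.values(), create_graph=True)

# The update wrapped in no_grad(), as in my failed attempt:
with torch.no_grad():
    weights = OrderedDict(
        (name, param - grad) if grad is not None else (name, param)
        for (name, param), grad in zip(weights.items(), grads)
    )

# Under no_grad() the updated tensors come out with requires_grad=False,
# so the next inner step can no longer differentiate through them.
print(next(iter(weights.values())).requires_grad)  # prints: False

My understanding is that this detaching is why the no_grad() route conflicts with keeping the inner loop differentiable, but I still do not understand the memory growth with create_graph=True.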
Your help is appreciated!