Is there any way to compute second-order gradients for DataParallel models?

Hi, I’m currently working on a project involving parallelized MAML training, and I need to compute second-order gradients for DataParallel models. However, the only way I have found to do this is
torch.autograd.grad(loss, model.module.params())
(params is a generator function I implemented myself that yields all parameters), which may cause imbalanced GPU utilization. Is there any way to solve this problem?

Hi,

I would say this is expected, as the gradients are all gathered back to the main GPU after being computed (so that you can do the weight update).
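
For reference, here is a minimal single-device sketch of the kind of second-order computation being discussed. The toy model, the random data, and inner_lr are made up for illustration and are not from your project. It only shows how torch.autograd.grad with create_graph=True lets you differentiate through an inner update; it does not change where DataParallel places the gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model and data; stands in for the real MAML learner.
model = nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)
inner_lr = 0.1  # assumed inner-loop step size

params = list(model.parameters())  # [weight, bias]

# Inner loss and its gradients. create_graph=True records the graph of the
# gradient computation so that it can be differentiated a second time.
inner_loss = F.mse_loss(model(x), y)
grads = torch.autograd.grad(inner_loss, params, create_graph=True)

# One differentiable SGD step: the adapted weights depend on the originals.
weight, bias = [p - inner_lr * g for p, g in zip(params, grads)]

# Outer loss evaluated with the adapted weights.
outer_loss = F.mse_loss(F.linear(x, weight, bias), y)

# Differentiating through the inner step includes the second-order terms.
meta_grads = torch.autograd.grad(outer_loss, params)
```

If create_graph=True is dropped, the inner gradients are treated as constants and meta_grads reduces to a first-order approximation.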

Hi, thanks for your reply. In my project, the imbalance in GPU utilization is quite severe. I want to optimize my model on 4 GPUs at the same time. Is there any way to solve this problem? Thank you.