Is it possible to access the model replicas in `torch.nn.DataParallel()` between the `forward()` and `.backward()` calls?

I am trying to design backward hooks that depend on the loss.
In the single-GPU version, I modify the hook between the forward and backward passes depending on the loss value.
I would like to scale to multi-GPU and modify each replicated hook depending on the loss on each GPU.
To do so, is it possible to access the model replicas in `torch.nn.DataParallel()` between the `forward()` and `.backward()` calls?
Thanks for helping :slightly_smiling_face:

`nn.DataParallel` does not expose its replicas, but you could modify it locally, or copy its code, so that the replicas are stored as a member field (this line)
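As a rough illustration of the "member field" idea, here is a minimal sketch. It assumes that `nn.DataParallel.forward()` still goes through the instance method `replicate(module, device_ids)`, which is an internal detail and may change between releases. The class name `DataParallelWithReplicas` and the `replicas` attribute are made up for this example.

```python
import torch
import torch.nn as nn


class DataParallelWithReplicas(nn.DataParallel):
    """Hypothetical subclass that remembers the replicas of the last forward pass."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Plain Python list, refreshed on every multi-GPU forward call.
        self.replicas = []

    def replicate(self, module, device_ids):
        # Assumption: forward() calls this method to build the per-device copies.
        self.replicas = super().replicate(module, device_ids)
        return self.replicas
```

After `output = dp_model(inputs)` and before `loss.backward()`, `dp_model.replicas` would then hold the copies used for that iteration. Note that the list stays empty when fewer than two devices are in use, because `nn.DataParallel` then calls the wrapped module directly, and that the replicas are rebuilt on every forward pass.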

It might be easier to install hooks with `DistributedDataParallel` (DDP) though, as DDP does not make new copies of the model in every iteration. Whatever hooks you install on the original module should therefore still be valid with DDP.
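To make that concrete, here is a small single-process sketch of a backward hook whose behavior is changed between the forward and backward passes. The `LossScaledHook` class and the scaling rule are invented for illustration. In a real DDP setup, each process would additionally wrap `model` in `torch.nn.parallel.DistributedDataParallel` after initializing the process group; that part is left out so the snippet runs anywhere.

```python
import torch
import torch.nn as nn


class LossScaledHook:
    """Hypothetical hook: scales the gradient flowing to the module's input."""

    def __init__(self):
        self.scale = 1.0  # updated by the training loop after the loss is known

    def __call__(self, module, grad_input, grad_output):
        # Return new tensors instead of modifying grad_input in place.
        return tuple(g * self.scale if g is not None else None for g in grad_input)


model = nn.Linear(4, 2)
hook = LossScaledHook()
handle = model.register_full_backward_hook(hook)

x = torch.randn(8, 4, requires_grad=True)
loss = model(x).sum()

# Between forward and backward: adapt the hook to this process's loss.
# The rule below is arbitrary and only serves the example.
hook.scale = 1.0 / (1.0 + abs(loss.item()))

loss.backward()
handle.remove()  # detach the hook once it is no longer needed
```

Since each DDP process keeps a single long-lived module, the hook object can be registered once and only its `scale` needs to be updated per iteration from the local loss.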

I’ll take a look at `DistributedDataParallel` first; it looks like the cleanest approach.
Thanks for your help @mrshenli !
