DataParallel with custom layer parameters

I am currently testing a custom layer that wraps a regular conv layer but also introduces an intermediate loss: the MSE loss between the image patch and the filter weights that produced the highest output for each image. My current implementation calculates the full loss inside the custom layer itself and uses self.register_parameter('unsupervised_loss', nn.Parameter(torch.zeros(1))) to store the intermediate loss.
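For reference, a minimal sketch of the setup described above (the layer name, patch extraction details, and loss reduction are my assumptions, not the actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvWithPatchLoss(nn.Module):
    """Hypothetical reconstruction: a conv layer that also computes an MSE loss
    between each image's best-responding patch and the filter that produced it."""

    def __init__(self, in_ch, out_ch, k):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k)
        # The problematic pattern: storing the intermediate loss on the module.
        self.register_parameter('unsupervised_loss', nn.Parameter(torch.zeros(1)))

    def forward(self, x):
        out = self.conv(x)
        n, c, h, w = out.shape
        k = self.conv.kernel_size[0]
        # Per image, find the filter and spatial location with the highest activation.
        idx = out.view(n, -1).argmax(dim=1)
        losses = []
        for i in range(n):
            f = idx[i].item() // (h * w)
            pos = idx[i].item() % (h * w)
            r, col = pos // w, pos % w
            patch = x[i, :, r:r + k, col:col + k]
            losses.append(F.mse_loss(patch, self.conv.weight[f]))
        # Overwrite the stored parameter. Under DataParallel this write lands on
        # a per-device replica, not on the module instance you actually hold.
        self.unsupervised_loss.data = torch.stack(losses).mean().view(1)
        return out
```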

While this works in the network by itself, the parameter only returns its original value of 0 when I wrap the model with DataParallel. This problem is similar to [DataParallel] Get updated module attributes, which has no answer, but I should also specify that I need gradients to flow through this loss. The original documentation has a warning about this behavior, stating that you can use in-place updates on device 0, but how would I extend that to my specific issue, where the updates are input-specific and not something like incrementing a counter on device 0?

I see two possible options I might test: either I somehow send the unsupervised losses to device 0, or I rewrite forward to return both the conv output and the unsupervised loss as a tuple. Both options seem messy, though.
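A sketch of the second option, under the same assumptions as before (DataParallel gathers forward outputs back to device 0 element-wise, so returning the loss lets it survive replication and keep its grad_fn):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvWithPatchLossTuple(nn.Module):
    """Hypothetical variant that returns (conv_output, unsupervised_loss)
    instead of storing the loss on the module."""

    def __init__(self, in_ch, out_ch, k):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k)

    def forward(self, x):
        out = self.conv(x)
        n, c, h, w = out.shape
        k = self.conv.kernel_size[0]
        # Simplified stand-in for the per-image patch/filter MSE term.
        idx = out.view(n, -1).argmax(dim=1)
        losses = []
        for i in range(n):
            f = idx[i].item() // (h * w)
            pos = idx[i].item() % (h * w)
            r, col = pos // w, pos % w
            losses.append(F.mse_loss(x[i, :, r:r + k, col:col + k], self.conv.weight[f]))
        # Keep a size-1 dim so DataParallel's gather can concatenate the
        # per-replica losses along dim 0.
        return out, torch.stack(losses).mean().view(1)
```

Under DataParallel the gathered loss would then have one element per replica, so you would reduce it (e.g. loss.mean()) before calling backward.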

I believe this is because DataParallel (DP) replicates the module to each provided device on every forward pass. So the module instance stored in DataParallel and the module instances used by forward on each device are different objects.

However, even if they were the same instance, it would still be problematic. Suppose we run DP on two devices: one device would overwrite the value of unsupervised_loss written by the other. Is this expected behavior?
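The replication issue can be illustrated on the CPU (DataParallel itself needs GPUs, so copy.deepcopy stands in here for DP's per-device replicate step; the module below is a made-up example, not the original layer):

```python
import copy
import torch
import torch.nn as nn

class StoresLoss(nn.Module):
    """Toy module that writes an input-dependent value into a stored parameter."""

    def __init__(self):
        super().__init__()
        self.register_parameter('stored', nn.Parameter(torch.zeros(1)))

    def forward(self, x):
        # The write happens on whichever copy of the module runs forward.
        self.stored.data = x.mean().view(1)
        return x

original = StoresLoss()
replica = copy.deepcopy(original)  # stand-in for what DP does per device
replica(torch.ones(4))

# The replica was updated, but the instance we hold still has its initial 0.
print(original.stored.item(), replica.stored.item())
```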

I instead rewrite the function to return both the conv output and the unsupervised loss as a tuple.

This seems like a reasonable workaround.

I ended up going with the tuple output to get things working. Unfortunately, this broke my original implementation of the network, since my custom layer was being used inside an nn.Sequential, so I had to rewrite some code to fix that.
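One way the Sequential rewrite could look (a sketch, not the poster's actual code): a container that accepts layers returning either a tensor or an (output, loss) tuple, and accumulates the extra losses.

```python
import torch
import torch.nn as nn

class LossCollectingSequential(nn.Module):
    """Hypothetical drop-in for nn.Sequential whose layers may return
    (output, loss) tuples; the extra losses are summed and returned."""

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        extra_losses = []
        for layer in self.layers:
            out = layer(x)
            if isinstance(out, tuple):
                x, loss = out
                extra_losses.append(loss)
            else:
                x = out
        total = torch.stack(extra_losses).sum() if extra_losses else x.new_zeros(())
        return x, total

class TupleLayer(nn.Module):
    """Toy layer with the (output, loss) interface, for demonstration."""

    def forward(self, x):
        return x * 2, x.sum().view(1)

net = LossCollectingSequential(nn.ReLU(), TupleLayer())
out, aux_loss = net(torch.ones(3))
```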