I am currently testing a custom layer that wraps a regular conv layer but also introduces an intermediate loss: the MSE loss between the image patch and the filter weights that produced the highest output for each image. My current implementation computes the full loss inside the custom layer itself and uses self.register_parameter('unsupervised_loss', nn.Parameter(torch.zeros(1)))
to store the intermediate loss.
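For concreteness, here is a stripped-down single-GPU sketch of the idea (the class name and the unfold-based patch bookkeeping are illustrative, not my exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvWithPatchLoss(nn.Module):
    """Illustrative sketch: a conv layer that also computes an MSE loss
    between each image's best-matching patch and the winning filter,
    stored in a registered parameter attribute."""

    def __init__(self, in_ch, out_ch, k):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k)
        self.k = k
        # The attribute that stays at 0 under DataParallel: each replica
        # writes to its own copy, and the writes never reach device 0.
        self.register_parameter('unsupervised_loss',
                                nn.Parameter(torch.zeros(1)))

    def forward(self, x):
        out = self.conv(x)                                  # (N, C, H', W')
        n, c, h, w = out.shape
        idx = out.reshape(n, -1).argmax(dim=1)              # best activation per image
        ch = torch.div(idx, h * w, rounding_mode='floor')   # winning filter index
        pos = idx % (h * w)                                 # winning spatial location
        patches = F.unfold(x, self.k)                       # (N, in_ch*k*k, H'*W')
        best_patch = patches[torch.arange(n), :, pos]       # (N, in_ch*k*k)
        best_filter = self.conv.weight[ch].reshape(n, -1)   # (N, in_ch*k*k)
        loss = F.mse_loss(best_patch, best_filter)
        # Writing through .data keeps the Parameter in place, but the
        # update is per-replica only (and detached from the graph).
        self.unsupervised_loss.data = loss.detach().reshape(1)
        return out
```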
While this works when the network runs by itself, the attribute only ever returns its initial value of 0 once I wrap the model with DataParallel. This problem is similar to [DataParallel] Get updated module attributes, which has no answer, but I also want to specify that I need the gradients to keep working. The original documentation has a warning about this behavior, stating that you can use in-place updates on device 0, but how would I extend that to my specific issue, where the operations are input-specific rather than something like incrementing a counter on device 0?
I see two possible options that I might test: either I somehow send the unsupervised losses to device 0, or I rewrite forward to return both the conv output and the unsupervised loss as a tuple. Both of these options seem messy, though.
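To make the second option concrete, here is a sketch of the tuple-returning variant (class name and patch bookkeeping are hypothetical, not my actual code). Since DataParallel gathers whatever forward returns back onto device 0 with the autograd graph intact, returning the loss as part of a tuple would give me both the value and working gradients:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvReturningPatchLoss(nn.Module):
    """Tuple-returning variant: the unsupervised loss is returned from
    forward instead of being stored in a module attribute."""

    def __init__(self, in_ch, out_ch, k):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k)
        self.k = k

    def forward(self, x):
        out = self.conv(x)                                  # (N, C, H', W')
        n, c, h, w = out.shape
        idx = out.reshape(n, -1).argmax(dim=1)              # best activation per image
        ch = torch.div(idx, h * w, rounding_mode='floor')   # winning filter index
        pos = idx % (h * w)                                 # winning spatial location
        patches = F.unfold(x, self.k)                       # (N, in_ch*k*k, H'*W')
        best_patch = patches[torch.arange(n), :, pos]       # (N, in_ch*k*k)
        best_filter = self.conv.weight[ch].reshape(n, -1)   # (N, in_ch*k*k)
        loss = F.mse_loss(best_patch, best_filter)
        # view(1) instead of a 0-dim scalar, so DataParallel's gather can
        # concatenate the per-replica losses along dim 0.
        return out, loss.view(1)

# multi-GPU usage sketch:
# model = nn.DataParallel(ConvReturningPatchLoss(3, 8, 5))
# out, unsup = model(x)              # unsup: one entry per replica
# (task_loss + unsup.mean()).backward()
```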