I have a loss function that uses the intermediate output at various points in a network. Currently, I just use a dummy class that holds the intermediate results, which I can inject at the points in the network where I need it:
```python
import torch.nn as nn

class Tracker(nn.Module):
    def __init__(self):
        super(Tracker, self).__init__()

    def forward(self, x):
        self.x = x  # stash the intermediate activation
        return x
```
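For context, here's a minimal sketch of how I inject it (the surrounding layers and sizes are just placeholders):

```python
import torch
import torch.nn as nn

class Tracker(nn.Module):
    """Identity module that stashes its input for later use."""
    def forward(self, x):
        self.x = x
        return x

tracker = Tracker()
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    tracker,          # capture the activation after the ReLU
    nn.Linear(16, 4),
)

out = model(torch.randn(2, 8))
# tracker.x now holds the intermediate activation
```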
This solution works fine on a single GPU, but it breaks when using data_parallel.
Any suggestions? Thanks.