Using intermediate results for loss (Data Parallel)

I have a loss function that uses the intermediate output at various points in a network. Currently, I use a dummy class that holds the intermediate result, which I can inject at whatever points of the network I need:

import torch.nn as nn

class Tracker(nn.Module):
    """Identity layer that stashes its input so the loss can read it later."""
    def __init__(self):
        super(Tracker, self).__init__()

    def forward(self, x):
        self.x = x  # record the intermediate activation on the module
        return x

This solution works fine on a single GPU, but it breaks under data_parallel: the module gets replicated onto each device for every forward pass, so the attribute is set on the short-lived replicas and never makes it back to the original module.
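
For context, a minimal sketch of the failure mode (the Sequential wrapper and layer sizes are invented for illustration):

import torch
import torch.nn as nn

class Tracker(nn.Module):
    def forward(self, x):
        self.x = x
        return x

net = nn.Sequential(nn.Linear(10, 10), Tracker())
model = nn.DataParallel(net).cuda()
out = model(torch.randn(8, 10).cuda())
# The forward pass ran on per-GPU copies of net, which are discarded
# afterwards, so the original Tracker never saw the assignment:
print(hasattr(net[1], 'x'))  # False when running on more than one GPU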

Any suggestions? Thanks.

Make these intermediate results outputs of your forward pass; DataParallel gathers everything each replica returns onto the output device.
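
In case a concrete version helps, here is a sketch of that suggestion (the two-layer Net and its sizes are made up for illustration): return the intermediate tensor alongside the final output, and DataParallel's gather step concatenates both across devices, so the loss sees full-batch tensors.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        h = torch.relu(self.fc1(x))  # intermediate needed by the loss
        out = self.fc2(h)
        # Returning h as a second output lets DataParallel gather it
        # across GPUs along with out.
        return out, h

model = nn.DataParallel(Net()).cuda()
out, h = model(torch.randn(8, 10).cuda())
loss = out.sum() + h.pow(2).mean()  # example loss using the intermediate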

I ended up using a list indexed by the device id of the tensor. This way I did not have to modify the structure of the network at all. Thanks.
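
For anyone landing here later, a minimal sketch of that approach (the names and sizes are illustrative, not from the original post): DataParallel's replication shallow-copies plain Python attributes, so every replica ends up holding a reference to the same list, and each replica writes its result into the slot for its own device.

import torch
import torch.nn as nn

class Tracker(nn.Module):
    def __init__(self, num_devices):
        super(Tracker, self).__init__()
        # One slot per GPU; all replicas share this same list object.
        self.storage = [None] * num_devices

    def forward(self, x):
        # CPU tensors report device.index as None; map them to slot 0.
        idx = x.device.index if x.device.index is not None else 0
        self.storage[idx] = x  # each replica fills only its own slot
        return x

tracker = Tracker(max(torch.cuda.device_count(), 1))
net = nn.Sequential(nn.Linear(10, 10), tracker)
model = nn.DataParallel(net).cuda()
out = model(torch.randn(8, 10).cuda())
# Move the per-device pieces onto one device before computing the loss:
parts = [t.to(out.device) for t in tracker.storage if t is not None]
hidden = torch.cat(parts, dim=0)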