Handling intermediate/latent tensors when using nn.DataParallel

Hi all,

I’ve been having some trouble with a U-Net style implementation in a multi-GPU setup. I’ve mostly traced the central problem to how latent tensors are handled across the skip connections (saving outputs in the down-convolution pathway and concatenating them into the up-convolution pathway). The network performs fine in a single-GPU setup.

The issue is not with the nn.DataParallel setup itself: the same data runs fine through a simple model with only convolution layers. The problem only appears when I try to save/concatenate intermediate tensors during the forward() pass. It doesn’t seem to be the convolution layers either, since the down pathway works, and if I turn off the latent/skip connections the up pathway works too. What confuses me is that scoping these tensors to self (e.g. self.encoder_outs) didn’t help, yet the plain local “x” is handled fine, and I don’t understand why that is.

Is there a workaround or fix? Perhaps some combination of proper scoping will work? Any help would be greatly appreciated, since we have multiple GPUs available and it would be great to actually train on them :slightly_smiling_face:

Thanks!
Andy


Here is the relevant code:

if self.first_time: print (" +++ Building Encoding Pathway +++ ")
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            if self.first_time: print ("Layer {0} :".format(i) + str(x.shape))
            x, self.before_pool = module(x)
            self.encoder_outs.append(self.before_pool)

        if self.first_time: print ("Bottlenecking Skip Layers")
        for i, skip_layer in enumerate(self.encoder_outs):
**L323**            skip_layer = self.skip_conv_ins[i](skip_layer)
            if self.first_time: print ("Layer {0}: ".format(i) + str(skip_layer.shape))
            skip_layer = self.skip_conv_outs[i](skip_layer)
            self.encoder_outs[i] = skip_layer

############################################
class BottleConv(nn.Module):
    """
    Collapses or expands (w.r.t. channels) latent layers. Composed of a single conv1x1 layer.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(BottleConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.conv1 = conv1x1(self.in_channels, self.out_channels)
        self.bn = nn.BatchNorm2d(self.out_channels)

    def forward(self, x):
**L80**        x = self.bn(F.relu(self.conv1(x)))
        return x

Here is the traceback; I’ve bolded the lines that correspond to the code above.

Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  **File "/home/andrew/Desktop/deep-learning-ihc/Encoders.py", line 323, in forward**
    skip_layer = self.skip_conv_ins[i](skip_layer)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  **File "/home/andrew/Desktop/deep-learning-ihc/Encoders.py", line 80, in forward**
    x = self.bn(F.relu(self.conv1(x)))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

I’m not sure I understand the code completely, but it looks like you are storing some activations in self.encoder_outs. Are these activations for debugging only or is this list part of the model?
In the latter case, how did you define self.encoder_outs and how would you like to reduce it?
I guess each replica will initialize its own self.encoder_outs, so that you end up having N lists.
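If those activations are only needed within a single forward pass, one common way to avoid this kind of device mismatch is to keep the list as a local variable inside forward(), so that each DataParallel replica builds and consumes its own list on its own device. Here is a minimal sketch of that pattern (TinyEncoder, skip_convs and the layer sizes are made up for illustration, not taken from your Encoders.py; only the local-list idea matters):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.down_convs = nn.ModuleList([
                nn.Conv2d(3, 8, 3, padding=1),
                nn.Conv2d(8, 16, 3, padding=1),
            ])
            # 1x1 convs standing in for the skip_conv_ins / skip_conv_outs pairs
            self.skip_convs = nn.ModuleList([
                nn.Conv2d(8, 8, 1),
                nn.Conv2d(16, 16, 1),
            ])
            self.pool = nn.MaxPool2d(2)

        def forward(self, x):
            # local list: each replica gets its own, with tensors on its own device
            encoder_outs = []
            for conv in self.down_convs:
                x = F.relu(conv(x))
                encoder_outs.append(x)
                x = self.pool(x)

            # bottleneck the saved skip layers, still working on local tensors
            skips = [sc(s) for sc, s in zip(self.skip_convs, encoder_outs)]

            # stand-in for the decoder: upsample and concatenate the last skip
            x = F.interpolate(x, scale_factor=2)
            return torch.cat([x, skips[-1]], dim=1)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(TinyEncoder()).cuda()
        out = model(torch.randn(8, 3, 32, 32).cuda())  # no device-mismatch error

Storing the same tensors on self (or appending to a list created in __init__) means the replicas can end up sharing or mixing state across devices, which matches the “device 1 does not equal 0” error you are seeing.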