Unable to run Dataparallel for Capsule Network

I am trying to run a model built with a Capsule Network on multiple GPUs using PyTorch's DataParallel, but I am getting an error. Below are the stack trace and the code.
I don't understand why it is not replicating the model across all the GPUs. Please suggest where I am going wrong.

cuda
Let's use 4 GPUs!
Devices: [0, 1, 2, 3]
Traceback (most recent call last):
  File "main.py", line 245, in <module>
    train(epoch)
  File "main.py", line 161, in train
    output = model(data) # forward.
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/engn8536/project_data/u6724013/caer/capsule_network.py", line 43, in forward
    h = self.primary_caps(h)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/engn8536/project_data/u6724013/caer/primary_caps.py", line 56, in forward
    u_i = self.conv_units[i](x)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/engn8536/project_data/u6724013/caer/primary_caps.py", line 26, in forward
    h = self.conv(x)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

main.py

parser.add_argument('--gpu', help="GPU_ID", type=str, default="0,1,2,3")
args = parser.parse_args()

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:"+re.split(r",",args.gpu)[0] if USE_CUDA else "cpu")
print(device)
gpu_id = list(map(int, re.split(r",",args.gpu)))

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  model = nn.DataParallel(model, device_ids=gpu_id).to(device)
  print('Devices:', model.device_ids)
  Use_Dataparallel = True
u_i = self.conv_units[i](x)

Is this a PyTorch list or a Python list?
It seems that either some layers aren't being sent to the GPU or you aren't properly sending the input. Can you post the model and how you are allocating the inputs?

I am actually using a Capsule Network and creating a stack of conv units to form the capsule layers.
Below is the code for the same.

class ConvUnit(nn.Module):
        def __init__(self, in_channels, out_channels):
                super(ConvUnit, self).__init__()

                self.conv = nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=9,
                        stride=2,
                        bias=True
                )

        def forward(self, x):
                # x: [batch_size, in_channels=256, 20, 20]

                h = self.conv(x)
                # h: [batch_size, out_channels=8, 6, 6]

                return h


class PrimaryCaps(nn.Module):
        def __init__(self):
                super(PrimaryCaps, self).__init__()

                self.conv1_out = 256 
                self.capsule_units = 32
                self.capsule_size = 8

                def create_conv_unit(unit_idx):
                        unit = ConvUnit(
                                in_channels=self.conv1_out,
                                out_channels=self.capsule_size
                        )
                        self.add_module("unit_" + str(unit_idx), unit)
                        return unit

                self.conv_units = [create_conv_unit(i) for i in range(self.capsule_units)]

        def forward(self, x):
                # x: [batch_size, 256, 20, 20]
                batch_size = x.size(0)

                u = []
                for i in range(self.capsule_units):
                        u_i = self.conv_units[i](x)
                        # u_i: [batch_size, capsule_size=8, 6, 6]

                        u_i = u_i.view(batch_size, self.capsule_size, -1, 1)
                        # u_i: [batch_size, capsule_size=8, 36, 1]

                        u.append(u_i)
                # u: [batch_size, capsule_size=8, 36, 1] x capsule_units=32

                u = torch.cat(u, dim=3)

                u = u.view(batch_size, self.capsule_size, -1)
                # u: [batch_size, capsule_size=8, 1152=36*32]

                u = u.transpose(1, 2)
                # u: [batch_size, 1152, capsule_size=8]

                u_squashed = squash(u, dim=2)
                # u_squashed: [batch_size, 1152, capsule_size=8]

                return u_squashed

For the model (capsule_network.py):

class CapsuleNetwork(nn.Module):
        def __init__(self, gpu, routing_iters=3, reconstruct=True):
                super(CapsuleNetwork, self).__init__()

                self.gpu = gpu
                self.has_reconstruction = reconstruct

                # Build modules for CapsNet.

                ## Convolution layer
                self.conv1 = Conv1()

                ## PrimaryCaps layer
                self.primary_caps = PrimaryCaps()

                ## EmotionCaps layer
                self.emotion_caps = EmotionCaps(routing_iters=routing_iters, gpu=gpu)

                ## Decoder for reconstruction
                if reconstruct:
                        self.decoder = Decoder()

        def forward(self, x):
                # x: [batch_size, 1, 28, 28]

                h = self.conv1(x)
                # h: [batch_size, 256, 20, 20]

                h = self.primary_caps(h)
                # h: [batch_size, 1152=primary_capsules, 8=primary_capsule_size]

                h = self.emotion_caps(h)
                # h: [batch_size, 10=digit_capsule, 16=digit_capsule_size]

                return h

        def loss(self, images, input, target, size_average=True):
                # images: [batch_size, 1, 28, 28]
                # input: [batch_size, 10, 16, 1]
                # target: [batch_size, 10]

                margin_loss = self.margin_loss(input, target, size_average)

                if self.has_reconstruction:
                        reconstruction_loss = self.reconstruction_loss(images, input, size_average)
                else:
                        reconstruction_loss = Variable(torch.zeros(1))
                        #if self.gpu >= 0:
                        reconstruction_loss = reconstruction_loss.to(device)

                loss = margin_loss + reconstruction_loss

                return loss, margin_loss, reconstruction_loss

        def margin_loss(self, input, target, size_average=True):
                # input: [batch_size, 10, 16]
                # target: [batch_size, 10]

                batch_size = input.size(0)

                # ||vc|| from the paper.
                v_mag = torch.sqrt((input**2).sum(dim=2, keepdim=True))
                # v_mag: [batch_size, 10, 1]

                # Calculate left and right max() terms from Eq.4 in the paper.
                zero = Variable(torch.zeros(1))
                #if self.gpu >= 0:
                zero = zero.cuda()
                m_plus = 0.9
                m_minus = 0.1
                max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1)**2
                max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1)**2
                # max_l, max_r: [batch_size, 10]

                # This is Eq.4 from the paper.
                loss_lambda = 0.5
                T_c = target
                # T_c: [batch_size, 10]
                L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
                # L_c: [batch_size, 10]
                L_c = L_c.sum(dim=1)
                # L_c: [batch_size]

                if size_average:
                        L_c = L_c.mean() # average over batch.
                else:
                        L_c = L_c.sum() # sum over batch.

                return L_c

        def reconstruction_loss(self, images, input, size_average=True):
                # images: [batch_size, 1, 28, 28]
                # input: [batch_size, 10, 16]

                batch_size = images.size(0)

                # Reconstruct input image.
                reconstructed = self.reconstruct(input)
                # reconstructed: [batch_size, 1, 28, 28]

                # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
                # Multiplied by a small number so it doesn't dominate the margin (class) loss.
                error = (reconstructed - images).view(batch_size, -1)
                error = error**2
                # error: [batch_size, 784=1*28*28]
                error = torch.sum(error, dim=1)
                # error: [batch_size]

                if size_average:
                        error = error.mean() # average over batch.
                else:
                        error = error.sum() # sum over batch.

                rec_loss_weight = 0.0005
                error *= rec_loss_weight
 
                return error

        def reconstruct(self, input):
                # input: [batch_size, 10, 16]

                assert self.has_reconstruction, 'Reconstruction path is disabled. To enable it, construct CapsuleNetwork with reconstruct=True.'

                # Get the lengths of capsule outputs.
                v_mag = torch.sqrt((input**2).sum(dim=2))
                # v_mag: [batch_size, 10]

                # Get index of longest capsule output.
                _, v_max_index = v_mag.max(dim=1)
                v_max_index = v_max_index.data
                # v_max_index: [batch_size]

                # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
                batch_size = input.size(0)
                all_masked = [None] * batch_size
                for batch_idx in range(batch_size):
                        # Get one sample from the batch.
                        input_batch = input[batch_idx]
                        # input_batch: [10, 16]

                        # Copy only the maximum capsule index from this batch sample.
                        # This masks out (leaves as zero) the other capsules in this sample.
                        batch_masked = Variable(torch.zeros(input_batch.size()))
                        #if self.gpu >= 0:
                        batch_masked = batch_masked.cuda()
                        batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
                        # batch_masked: [10, 16]

                        all_masked[batch_idx] = batch_masked
                # all_masked: [10, 16] * batch_size

                # Stack masked capsules over the batch dimension.
                masked = torch.stack(all_masked, dim=0)
                # masked: [batch_size, 10, 16]
                masked = masked.view(batch_size, -1)
                # masked: [batch_size, 160]

                # Reconstruct input image.
                reconstructed = self.decoder(masked)
                # reconstructed: [batch_size, 1, 28, 28]

                return reconstructed

Hi,
I would say the problem is that you are keeping the layers in a plain Python list.
Can you use nn.ModuleList (https://pytorch.org/docs/stable/nn.html?highlight=modulelist#torch.nn.ModuleList) instead?
An nn.Module is not aware of layers stored in a plain Python list, so DataParallel cannot move or replicate them to the other devices.
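Something along these lines should work, as a minimal sketch based on the PrimaryCaps you posted above (the forward loop can stay exactly the same, since nn.ModuleList supports indexing and len()):

class PrimaryCaps(nn.Module):
        def __init__(self):
                super(PrimaryCaps, self).__init__()

                self.conv1_out = 256
                self.capsule_units = 32
                self.capsule_size = 8

                # nn.ModuleList registers every ConvUnit as a submodule, so its
                # parameters are moved by .to(device) and replicated by DataParallel.
                self.conv_units = nn.ModuleList([
                        ConvUnit(
                                in_channels=self.conv1_out,
                                out_channels=self.capsule_size
                        ) for _ in range(self.capsule_units)
                ])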

Thanks…it worked :slight_smile: