AssertionError: Gather function not implemented for CPU tensors

Hi,

I have been trying to use nn.DataParallel for my model. But I keep getting above error. I have not been able to produce the error in any other settings. My model is as follows:

class SSS(torch.nn.Module):


  def __init__(self):
      super(SSS, self).__init__()
      self.effnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True)
      self.features = nn.Sequential(*list(self.effnet.children())[:-1])
      self.AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d(1)
      self.fc1 = nn.Linear(in_features = 1280, out_features = 1024)
      self.fc2 = nn.Linear(in_features = 1280, out_features = 1024)

  def forward(self, x):
      features = self.features(x)
      
      top_left = features[:,:,:int(features.shape[2]/2),:int(features.shape[3]/2)]
      top_right = features[:,:,:int(features.shape[2]/2),int(features.shape[3]/2):features.shape[3]]
      bottom_left = features[:,:,int(features.shape[2]/2):features.shape[2],:int(features.shape[3]/2)]
      bottom_right = features[:,:,int(features.shape[2]/2):features.shape[2],int(features.shape[3]/2):features.shape[3]]
      
      top_left = self.AdaptiveAvgPool2d(top_left)
      top_left = torch.squeeze(top_left)
      top_left = self.fc1(top_left)
      
      top_right = self.AdaptiveAvgPool2d(top_right)
      top_right = torch.squeeze(top_right)
      top_right = self.fc1(top_right)
      
      bottom_left = self.AdaptiveAvgPool2d(bottom_left)
      bottom_left = torch.squeeze(bottom_left)
      bottom_left = self.fc1(bottom_left)
      
      bottom_right = self.AdaptiveAvgPool2d(bottom_right)
      bottom_right = torch.squeeze(bottom_right)
      bottom_right = self.fc1(bottom_right)
      
      features = self.AdaptiveAvgPool2d(features)
      features = torch.squeeze(features)
      features = self.fc2(features)
      
      output = torch.empty((x.shape[0],5,1024))
      
      output[:,0] = top_left
      output[:,1] = top_right
      output[:,2] = bottom_left
      output[:,3] = bottom_right
      output[:,4] = features
      
      return output

The data parallel setting is also written as below:

  os.environ['CUDA_VISIBLE_DEVICES'] = "0,1" 
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  
  netx5 = SSS()
  netx10 = SSS()
  netx20 = SSS()
 
  netx5 = nn.DataParallel(netx5)
  netx10 = nn.DataParallel(netx10)
  netx20 = nn.DataParallel(netx20)
  
  netx5.to(device)
  netx10.to(device)
  netx20.to(device)
  
  criterion = nn.MSELoss()
  lr = 0.001
  optimizer = torch.optim.Adam(list(netx5.parameters()) + list(netx10.parameters()) +  list(netx20.parameters()) , lr=lr)

the train function is also written as below:

  def train(train_loader, netx5, netx10, netx20, optimizer, criterion):
      main_loss = 0
      for data in train_loader:
          x5 = data[0].to(device)
          x10 = data[1].permute(1,0,2,3,4).to(device)
          x20 = data[2].permute(1,2,0,3,4,5).to(device)
          
          x5_features = netx5(x5)
          
          x10_features = torch.empty((4,x5.shape[0],5,1024))
          x20_features = torch.empty((4,4,x5.shape[0],5,1024))
          for i, crop in enumerate(x10):
              # print(crop.shape)
              x10_features[i] = netx10(crop)
              
          for i, crop in enumerate(x20):
              for j , subcrop in enumerate(crop):
                  x20_features[i,j] = netx20(subcrop)
                  
          logits = torch.empty((x5.shape[0],22,1024))
          labels = torch.empty((x5.shape[0],22,1024))
          
          for i in range(4):
              logits[:,i] = x5_features[:,i]
              labels[:,i] = x10_features[i,:,4]
              
          for i in range(4):
              for j in range(4):
                  logits[:, i*4 + j + 4] = x10_features[i,:,j]
                  labels[:, i*4 + j + 4] = x20_features[i,j,:,4]
                  
          
          logits[:,20] = x5_features[:,4]
          labels[:,20] = torch.mean(x10_features.permute(1,0,2,3)[:,:,4], dim = 1)
          
          # print(x10_features.permute(1,0,2,3)[:,:,4].shape)
          # print(torch.sum(x20_features.permute(2,0,1,3,4)[:,:,:,4], dim = 1).shape)
          
          logits[:,21] = torch.mean(x20_features.permute(2,0,1,3,4)[:,:,:,4].reshape(x5.shape[0],16,1024), dim = 1)
          labels[:,21] = x5_features[:,4]
          
             
          optimizer.zero_grad()
  
          
          loss = criterion(logits,labels)
          loss.backward()
          
          optimizer.step()
  
          
          main_loss += loss.item()
          
      normalized_loss = main_loss/x5.shape[0]
      # scheduler.step() 
          # loss_graph.append() Populate this list to graph the loss
      loss_graph.append(normalized_loss)    
      return normalized_loss

The complete error:

Traceback (most recent call last):
  File "pretex.py", line 303, in <module>
    loss = train(train_loader, netx5, netx10, netx20, optimizer, criterion)
  File "pretex.py", line 242, in train
    x5_features = netx5(x5)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.gather(outputs, self.output_device)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 181, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 78, in gather
    res = gather_map(outputs)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/pashrafi/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 57, in forward
    'Gather function not implemented for CPU tensors'
AssertionError: Gather function not implemented for CPU tensors

I have tried passing .to(device) before and after nn.DataParallel. I would appreciate if you could help me.