Dimension mismatch with multiple GPUs

Here is the situation. A custom DataLoader is used to load the train/val/test data.
The model runs on a single GPU, but not on multiple GPUs.

import torch

class EncoderDecoder(torch.nn.Module):
    def forward(self, feats, masks, ...):
        clip_masks = self.clip_feature(masks, feats)
        ....

    def clip_feature(self, masks, feats):
        '''
        This function clips the padded input features to the same dim.
        '''
        # longest unpadded sequence in the batch this call receives
        max_len = masks.data.long().sum(1).max()
        print('max_len:%d' % max_len)
        masks = masks[:, :max_len].contiguous()
        ....

        return masks
    ......

def train(opt):
    model = EncoderDecoder(opt)

    # setting-1
    cuda_model = model.cuda().train()

    # setting-2
    # cuda_model = torch.nn.DataParallel(model.cuda())

    cuda_model.train()
    torch.cuda.synchronize()
    ...

If I launch the model on a single GPU (marked as “setting-1”), it works but takes days. The tensors returned by clip_feature are as expected. The debug info is as follows:

masks.shape (150, 61)
EncoderDecoder clip_feature masks.shape in (150, 61)
masks.device:cuda:0
max_len:61
masks.shape clip_att (150, 61)
max_len:61
masks.size (150, 61)
att_mask.device cuda:0

If I use DataParallel instead of a single GPU (marked as “setting-2”), the results change:

EncoderDecoder clip_feature masks.shape in (38, 61)
masks.device:cuda:0
masks.shape (38, 61)
EncoderDecoder clip_feature masks.shape in (38, 61)
masks.device:cuda:1
masks.shape (38, 61)
RelationTransformer clip_feature att_masks.shape in (38, 61)
masks.device:cuda:2
max_len:50
max_len:50

Later it raises a runtime error in a multiplication I need:

RuntimeError: The size of tensor a (61) must match the size of tensor b (60) at non-singleton dimension 3

I have no idea how this happens. The batched input is dispatched to different devices, but the results are totally different from those returned on a single GPU. I do not think it depends on the parallel dispatching across GPUs. Maybe I missed some configuration for my model. The running environment is as follows (I tested with different torch versions):

  • torch 0.4.1 / 1.4.0+cu100
  • torchvision 0.2.1/ 0.5.0+cu100
  • 4 x Tesla V100-SXM2 Driver Version: 410.104 CUDA Version: 10.0

Any input would be appreciated. Thanks.

Which line of code is throwing this error?
Could you add the device information to the max_len print, as I’m not sure where the 50 is coming from, since the masks are cropped to 61.

ptrblck, thanks for your inputs.

Here is the debug info again, before and after calling the clip_feature function when running on multiple GPUs.

EncoderDecoder clip_feature masks.shape in (38, 61)
masks.device:cuda:0
EncoderDecoder clip_feature masks.shape in (38, 61)
masks.device:cuda:1
EncoderDecoder clip_feature masks.shape in (38, 61)
masks.device:cuda:2
max_len:50      max_len.device: cuda:0
EncoderDecoder clip_feature masks.shape in (36, 61)
max_len:54       max_len.device: cuda:1
max_len:61      max_len.device: cuda:2
masks.device:cuda:3
max_len:51      max_len.device: cuda:3

It’s a big jump to the place where the error is raised, since there are lots of operations before the call to relation_geo_attention. Since the whole snippet works on a single GPU, I guess I need to post all the code here (hopefully I’m right :D). The trigger is that the dimension after clip_feature differs from what I expect (i.e. from the single-GPU output). The error comes from the lines marked with ^^^^.

def relation_geo_attention(query, key, value, box_embd_matrix, mask=None):
    N = value.size()[:2]
    dim_k = key.size(-1)
    dim_g = box_embd_matrix.size()[-1]

    w_q = query
    w_k = key.transpose(-2, -1)
    w_v = value
    w_g = box_embd_matrix
	
    #attention weights
    scaled_dot = torch.matmul(w_q,w_k)
    w_a = scaled_dot / np.sqrt(dim_k)
    if mask is not None:
        w_a = w_a.masked_fill(mask == 0, -1e9)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        #! RuntimeError occurs from here

    # calculate the relation between geometry and features
    w_mn = torch.log(torch.clamp(w_g, min = 1e-6)) + w_a
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    #! RuntimeError possibly occurs here as well

After the features are clipped, the returned dimension is not as expected, so all the following operations in the transformer (the relation_geo_attention function) go wrong. I do not think it’s a scatter problem across devices.

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
....
....
w_a = w_a.masked_fill(mask == 0, -1e9)
RuntimeError: The size of tensor a (61) must match the size of tensor b (50) at non-singleton dimension 3

What really confuses me is why it works on a single device but raises errors when running on n GPUs.

Thanks
Anakin

If the code runs on a single GPU, all the data is in one big batch. The whole batch is fed to the 8-head attention model to calculate the relation between geometry and appearance, so all the dimensions are perfectly aligned. If the code is launched on multiple GPUs, each device gets its own chunk of the batch, which of course results in different dimensions depending on the inputs. That’s the problem as I understand it, if I’m correct. Still looking for a solution.
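To make that point concrete, here is a minimal torch-free sketch. Plain nested lists stand in for the mask tensor, the six-row batch and its split into two halves are made up for illustration, and the hypothetical helper max_len mirrors masks.sum(1).max() from clip_feature. It only shows that a chunk can compute a smaller max_len than the full batch would.

```python
def max_len(mask_rows):
    """Longest unpadded row, analogous to masks.sum(1).max() in clip_feature."""
    return max(sum(row) for row in mask_rows)

# Hypothetical batch: 6 masks padded to width 5 (1 = real token, 0 = padding).
lengths = [2, 3, 3, 2, 5, 4]
masks = [[1] * n + [0] * (5 - n) for n in lengths]

# Single device: one call sees the whole batch.
full_batch_len = max_len(masks)

# Two devices: each replica only sees its own half of dim 0.
chunks = [masks[:3], masks[3:]]
per_chunk_lens = [max_len(chunk) for chunk in chunks]

print(full_batch_len)   # 5
print(per_chunk_lens)   # [3, 5]
```

Any tensor that is still sized with the padded width (5 here) will no longer match features clipped to 3 on the first replica.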

Could you add an assert statement and check that all tensors in your relation_geo_attention function are on the same device?

nn.DataParallel will split the data tensors in dim0 and send each chunk to the corresponding device.
I.e. if the first chunk has a batch_size of 51, all other tensors passed to the forward method will have the same batch size.

Also, make sure to use the forward method, as nn.DataParallel uses this method to split the data.
If you are using a custom function such as model.my_fun(data), you would have to take care of the splitting yourself.
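A rough sketch of the resulting chunk sizes, assuming dim 0 is cut into contiguous chunks of ceil(batch_size / num_devices) rows. The helper chunk_sizes is hypothetical (it is not a PyTorch function); it only reproduces the sizes that appear in the logs of this thread.

```python
import math

def chunk_sizes(batch_size, num_devices):
    """Sizes of the dim-0 chunks, assuming ceil(batch_size / num_devices) rows each."""
    size = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(size, remaining))
        remaining -= size
    return sizes

print(chunk_sizes(150, 4))  # [38, 38, 38, 36]
print(chunk_sizes(150, 2))  # [75, 75]
```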

Thanks for your reply again, ptrblck.
I think you probably got confused by my description. So far, the data is correctly split into chunks:

[38 x 61] [38 x 61] [38 x 61] [36 x 61]

The sequence length is what causes the error. Concretely, I have a relation model that calculates the relation between geometry and appearance. This relation model is organized as a cascade in a ModuleList. The attentions are fed to the first module in the ModuleList, and its output is then fed to the following modules.
If the data is scattered to a single device, all the relation models see aligned (identical) dimensions, and it works. If the data is chunked and scattered across 4 devices, the features have different lengths within a single device, e.g. device 0, as the debug info below suggests.

    if mask is not None:
        print('mask.size', mask.size())
        print('mask.device:\t%s' % mask.device)
        print('w_a.size', w_a.size())
        print('w_a.device:\t%s' % w_a.device)

        assert query.device == key.device, 'query and key are not on the same device'
        assert value.device == key.device, 'value and key are not on the same device'
        assert query.device == box_relation_embds_matrix.device, 'query and box are not on the same device'
        assert query.device == mask.device, 'query and mask are not on the same device'
        assert w_a.device == mask.device, 'w_a and mask are not on the same device'

        w_a = w_a.masked_fill(mask == 0, -1e9)

All the assertions hold. Here are the corresponding debug outputs (for simplicity and convenience, I used two devices):

###clip_feature
EncoderDecoder clip_feature masks.shape in torch.Size([75, 61])
masks.device:cuda:0
EncoderDecoder clip_feature masks.shape in torch.Size([75, 61])
masks.device:cuda:1
max_len:54      max_len.device: cuda:0
max_len:61      max_len.device: cuda:1

### got padded somewhere later
padded.size torch.Size([75, 61, 512])
padded.size torch.Size([75, 54, 512])

### the first relation model info
 mask.size torch.Size([75, 1, 1, 61])
 mask.device:       cuda:1
 w_a.size torch.Size([75, 8, 61, 61])
 w_a.device:        cuda:1

### the second relation model info
 mask.size torch.Size([75, 1, 1, 61])
                       ^^^^^^^^^^^^^^
 mask.device:       cuda:0
 w_a.size torch.Size([75, 8, 54, 54])
                       ^^^^^^^^^^^^^^
 w_a.device:        cuda:0

As the output above shows, the problem comes from device 0, where the mask keeps length 61 while w_a has length 54.
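The shapes printed above can be checked against the usual broadcasting rule (two sizes are compatible if they are equal or one of them is 1). The sketch below is torch-free, works on shape tuples only, and the helper incompatible_dims is a hypothetical name introduced for illustration.

```python
def incompatible_dims(shape_a, shape_b):
    """Dims where two equal-rank shapes cannot broadcast:
    the sizes differ and neither of them is 1."""
    assert len(shape_a) == len(shape_b), "sketch only handles equal ranks"
    return [
        d for d, (a, b) in enumerate(zip(shape_a, shape_b))
        if a != b and a != 1 and b != 1
    ]

# Shapes copied from the two-device log above.
w_a_dev1, mask_dev1 = (75, 8, 61, 61), (75, 1, 1, 61)
w_a_dev0, mask_dev0 = (75, 8, 54, 54), (75, 1, 1, 61)

print(incompatible_dims(w_a_dev1, mask_dev1))  # []
print(incompatible_dims(w_a_dev0, mask_dev0))  # [3]
```

The mismatch on device 0 is in dimension 3, which is the dimension named in the RuntimeError.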

For comparison, here is the log from a single device:

# clip_feature
EncoderDecoder clip_feature masks.shape in torch.Size([150, 61])
masks.device:cuda:0
max_len:61      max_len.device: cuda:0
padded.size torch.Size([150, 61, 512])
### the first relation model info
mask.size torch.Size([150, 1, 1, 61])
mask.device:       cuda:0
w_a.size torch.Size([150, 8, 61, 61])
w_a.device:        cuda:0

### the same for other relation models

So far it looks like a coding problem rather than a scattering problem. I’ll keep digging. Any input is appreciated.

Yeah, I see that the batch dimension seems to be alright in your first output ([38, 61] ...).
However, why is the mask size [75, 61]? If the batch chunks have a batch size of 38 or 36, I’m not sure why your masks suddenly have a non-matching shape.

As I stated in my previous post, for simplicity and convenience I tested it on just two devices. That’s why the size is [75, 61]. The four chunks come from the run on 4 devices.
Thanks for your reply ptrblck

Hello,
I am having a similar issue. My code runs fine on a single GPU and on multiple GPUs on a server. Recently I built my own multi-GPU setup, and there I get this issue. Have you found a solution, @Anakin?

@Anakin @abhidipbhattacharyya I ran into a similar issue. When implementing multi-head attention I need to use torch.masked_fill(), and the code runs fine on a single GPU but raises a dimension mismatch error on multiple GPUs. In my case, I found that the first dimension of the mask I pass to the forward method is not batch_size, but DataParallel splits tensors along dimension 0 by default, which causes the mismatch. So what we need to do is add a batch_size dimension to the mask with mask = mask.unsqueeze(0) before passing the mask to the model; after the split, all devices will get the same mask because the batch_size is 1 here. Also remember that your mask then already has a batch_size dimension, so you don’t need to add this dimension again in the self-attention logic.

To sum up, ensure all tensors you pass to the DataParallel module have a batch_size dimension.
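A small torch-free sketch of why the leading dimension matters. It works on shape tuples only and assumes dim 0 is split into chunks of ceil(n / num_devices) rows; the helper scatter_shapes and the sizes used are hypothetical. Note that it illustrates a variant in which the mask gets a full batch dimension of the same size as the inputs, which is an assumption of this sketch and not the exact unsqueeze(0) call described above.

```python
import math

def scatter_shapes(shape, num_devices):
    """Per-device shapes after splitting dim 0 into chunks of
    ceil(n / num_devices) rows (assumed splitting rule)."""
    n, rest = shape[0], tuple(shape[1:])
    size = math.ceil(n / num_devices)
    return [(min(size, n - start),) + rest for start in range(0, n, size)]

batch, seq_len, devices = 8, 6, 2

inputs = scatter_shapes((batch, seq_len, 512), devices)
unbatched_mask = scatter_shapes((seq_len, seq_len), devices)       # no batch dim
batched_mask = scatter_shapes((batch, seq_len, seq_len), devices)  # batch dim first

print(inputs)          # [(4, 6, 512), (4, 6, 512)]
print(unbatched_mask)  # [(3, 6), (3, 6)]
print(batched_mask)    # [(4, 6, 6), (4, 6, 6)]
```

Without a batch dimension the mask itself is cut in half along its first sequence axis, so each replica sees a (3, 6) mask next to inputs of length 6. With the batch dimension in front, every replica receives a mask chunk whose leading size matches its input chunk.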