Unable to use DataParallel + LSTM + batch_first=False + packed_sequence

Here is a minimal example that reproduces the error:

import torch
import torch.nn as nn

device = torch.device('cuda')
lstm = nn.DataParallel(nn.LSTM(1, 5, batch_first=False), dim=1).to(device)
batch_size = 30
max_length = 20
lengths = torch.tensor([max_length] * batch_size, device=device)
inputs = torch.zeros(max_length, batch_size, 1, device=device)
inputs_pack = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=False)
outputs, hidden = lstm(inputs_pack)

which fails with the following exception:

Dimension out of range (expected to be in range of [-1, 0], but got 1)

Can you include the trace of the exception?

Hi, the trace is quite long.

But now I think I know what the problem is. A PackedSequence is a tuple of tensors (data and batch_sizes), so DataParallel's scatter recurses into it and tries to split every field along dim=1. The batch_sizes tensor is 1-D, so it has no dimension 1, which is exactly the "dimension out of range" error above: the packed sequence simply cannot be divided along a batch dimension.
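You can see the shapes involved with a quick check (just illustrative, run on CPU, not part of the snippet above):

import torch

inputs = torch.zeros(20, 30, 1)  # (max_length, batch_size, features)
packed = torch.nn.utils.rnn.pack_padded_sequence(inputs, [20] * 30, batch_first=False)
print(packed.data.shape)         # torch.Size([600, 1]) -> can be split along dim=1
print(packed.batch_sizes.shape)  # torch.Size([20])     -> 1-D, so dim=1 is out of range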

The solution is to wrap the LSTM in another module and do the packing inside that module's forward function, so that DataParallel only has to scatter ordinary padded tensors.
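Something along these lines should work (a rough sketch; the PackedLSTM name and the (1, batch_size) layout for the lengths are just my own choices here):

import torch
import torch.nn as nn

class PackedLSTM(nn.Module):
    # Packing happens inside forward(), so DataParallel only scatters plain tensors.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=False)

    def forward(self, inputs, lengths):
        # inputs:  (max_length, batch_chunk, input_size), split along dim=1 by DataParallel
        # lengths: (1, batch_chunk), split along the same dim=1 as the inputs
        lengths = lengths.squeeze(0).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=False)
        packed_out, hidden = self.lstm(packed)
        # total_length keeps every replica's padded output the same size,
        # so DataParallel can gather them back along dim=1
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=False, total_length=inputs.size(0))
        return outputs, hidden

device = torch.device('cuda')
model = nn.DataParallel(PackedLSTM(1, 5), dim=1).to(device)

batch_size = 30
max_length = 20
inputs = torch.zeros(max_length, batch_size, 1, device=device)
lengths = torch.full((1, batch_size), max_length, dtype=torch.int64, device=device)
outputs, hidden = model(inputs, lengths)

Giving the lengths an extra leading dimension lets DataParallel chunk them along dim=1 together with the inputs instead of choking on a 1-D tensor.

Anyway, here is the full trace: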

RuntimeError                              Traceback (most recent call last)                                                                                   
<ipython-input-5-4bf0d856e1ee> in <module>                                                                                                                    
      6 inputs = torch.zeros(max_length, batch_size, 1, device=device)                                                                                        
      7 inputs_pack = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=False)                                                             
----> 8 outputs, hidden = lstm(inputs_pack)                                                                                                                   
                                                                                                                                                              
/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)                               
    487             result = self._slow_forward(*input, **kwargs)                                                                                             
    488         else:                                                                                                                                         
--> 489             result = self.forward(*input, **kwargs)                                                                                                   
    490         for hook in self._forward_hooks.values():                                                                                                     
    491             hook_result = hook(self, input, result)                                                                                                   
                                                                                                                                                              
/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)                       
    137         if not self.device_ids:
    138             return self.module(*inputs, **kwargs)
--> 139         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    140         if len(self.device_ids) == 1:
    141             return self.module(*inputs[0], **kwargs[0])

/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in scatter(self, inputs, kwargs, device_ids)
    148
    149     def scatter(self, inputs, kwargs, device_ids):
--> 150         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
    151
    152     def parallel_apply(self, replicas, inputs, kwargs):

/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py in scatter_kwargs(inputs, kwargs, target_gpus, dim)
     33 def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
     34     r"""Scatter with support for kwargs dictionary"""
---> 35     inputs = scatter(inputs, target_gpus, dim) if inputs else []
     36     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     37     if len(inputs) < len(kwargs):                                                                                                 
                                                                                                               
/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py in scatter(inputs, target_gpus, dim)
     26     # None, clearing the cell                                                                          
     27     try:                                                                                             
---> 28         return scatter_map(inputs)                                                                  
     29     finally:                                                                                        
     30         scatter_map = None                                                                                                                            
                                                                                                                
/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py in scatter_map(obj)                                      
     13             return Scatter.apply(target_gpus, None, dim, obj)
     14         if isinstance(obj, tuple) and len(obj) > 0:
---> 15             return list(zip(*map(scatter_map, obj)))
     16         if isinstance(obj, list) and len(obj) > 0:                                                                                                    
     17             return list(map(list, zip(*map(scatter_map, obj))))                               
                                                                                                                                                              
/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py in scatter_map(obj)
     13             return Scatter.apply(target_gpus, None, dim, obj)                                                                                         
     14         if isinstance(obj, tuple) and len(obj) > 0:                                                 
---> 15             return list(zip(*map(scatter_map, obj)))                                                                                                  
     16         if isinstance(obj, list) and len(obj) > 0:                                                                                                    
     17             return list(map(list, zip(*map(scatter_map, obj))))                                     
                                                                                                                                                         
/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py in scatter_map(obj)                                 
     11     def scatter_map(obj):                                                                              
     12         if isinstance(obj, torch.Tensor):                                                           
---> 13             return Scatter.apply(target_gpus, None, dim, obj)                                          
     14         if isinstance(obj, tuple) and len(obj) > 0:                                                     
     15             return list(zip(*map(scatter_map, obj)))                                                
                                                                                                            
/lib/python3.6/site-packages/torch/nn/parallel/_functions.py in forward(ctx, target_gpus, chunk_sizes, dim, input)
     87             # Perform CPU to GPU copies in a background stream                                      
     88             streams = [_get_stream(device) for device in target_gpus]                                  
---> 89         outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)                    
     90         # Synchronize with the copy stream                                                                                                            
     91         if streams is not None:                                                            
                                                                                                                
/lib/python3.6/site-packages/torch/cuda/comm.py in scatter(tensor, devices, chunk_sizes, dim, streams)
    146         ``devices``.                                                                                   
    147     """                                                                                              
--> 148     return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))                        
    149                                                                                                      
    150                                                                                                     
                                                                                                            
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)