Multiple Conv1d (with different kernel sizes) over the same input

I have defined a model which performs convolutions over a batch of character sequences:

kernels = [3, 4, 5, 6]
cnns = []
for k in kernels:
    # one branch per kernel size: conv -> tanh -> max-pool over the time dimension
    seq = nn.Sequential(
        nn.Conv1d(char_embed_size, output_size // 4, k, padding=0),
        nn.Tanh(),
        nn.MaxPool1d(max_seq_length - k + 1)
    )
    cnns.append(seq)
self.cnns = nn.ModuleList(cnns)

In my forward method I obtain a representation for each sequence using:

def forward(self, char_emb):
    # char_emb has shape (batch_size, char_embed_size, max_seq_len)
    tmp = [cnn(char_emb).squeeze() for cnn in self.cnns]
    seq_representations = torch.cat(tmp, dim=1)
    return seq_representations
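
To make the shapes concrete: each conv output has length max_seq_length - k + 1, the max pool collapses that to 1, and the squeeze leaves (batch_size, output_size // 4) per branch, so the cat yields (batch_size, output_size). A quick check with made-up sizes:

import torch
import torch.nn as nn

# made-up sizes, just to verify the shapes
char_embed_size, output_size, max_seq_length = 16, 64, 20
branch = nn.Sequential(
    nn.Conv1d(char_embed_size, output_size // 4, 3, padding=0),
    nn.Tanh(),
    nn.MaxPool1d(max_seq_length - 3 + 1),
)
x = torch.randn(8, char_embed_size, max_seq_length)
print(branch(x).shape)            # torch.Size([8, 16, 1])
print(branch(x).squeeze().shape)  # torch.Size([8, 16]); cat over 4 branches -> (8, 64)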

Is there a way to avoid the sequential loop [cnn(char_emb).squeeze() for cnn in self.cnns] and have all the cnns in self.cnns perform their convolutions over the input in parallel?


CUDA kernels are run asynchronously, so your list comprehension isn’t a synchronization point.
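
Note this also means naive wall-clock timing is misleading; you have to synchronize around the region you measure. A minimal sketch (cnns and emb here are placeholders for your module list and input):

import time
import torch

def timed(cnns, emb):
    torch.cuda.synchronize()  # drain pending GPU work before starting the clock
    start = time.time()
    outs = [cnn(emb).squeeze() for cnn in cnns]
    out = torch.cat(outs, dim=1)
    torch.cuda.synchronize()  # wait for the queued kernels to actually finish
    return out, time.time() - start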

But I see a slowdown with each kernel size I add. Is that just overhead due to other factors? I don’t have exact timing numbers, but 1 kernel is definitely a lot faster than 4.

Oh right, you should run them on different CUDA streams: https://pytorch.org/docs/master/cuda.html#streams-and-events.

But in general you should expect some slowdown when parallelizing tasks; there will always be overhead, and usually not all tasks can run at 100% speed.

Thanks for the quick reply, and for the pointer regarding streams and events! It’s not clear to me from a quick read-through how to use them. Are there any working examples of streams you can refer me to?

Within a CUDA stream, kernels run sequentially, but different streams can run in parallel. By default all ops run on the default stream (stream 0), so I suggest running each forward pass in a separate stream.
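
A minimal self-contained sketch of the pattern (two independent matmuls standing in for your conv branches):

import torch

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()

torch.cuda.synchronize()  # make sure the inputs are ready before the side streams read them
with torch.cuda.stream(s1):
    c = a @ a             # queued on s1
with torch.cuda.stream(s2):
    d = b @ b             # queued on s2, may overlap with s1
torch.cuda.synchronize()  # rejoin before the default stream uses c and d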


Thanks again! I am now doing:

def forward(self, emb):
    # old way
    tmp = [cnn(emb).squeeze() for cnn in self.cnns]
    seq_representation = torch.cat(tmp, dim=1)

    # new way
    stream_tmp = []
    streams = [(idx, torch.cuda.Stream()) for idx, _ in enumerate(self.cnns)]
    for idx, s in streams:
        with torch.cuda.stream(s):
            cnn = self.cnns[idx]  # <--- how to ensure idx is in sync with the idx in the for loop?
            stream_tmp.append((idx, cnn(emb).squeeze()))
    stream_tmp = [t for idx, t in sorted(stream_tmp)]
    seq_representation_stream = torch.cat(stream_tmp, dim=1)

    # comparing the two
    diff = abs(seq_representation_stream - seq_representation).sum().data[0]
    print(diff)
    assert diff == 0.0
    return seq_representation

On some random batches the assert fails (the diff is very large, > 1000, so it’s not a rounding error).

I am pretty sure it is because the idx in the for loop is not in sync with the idx inside the with torch.cuda.stream(s) block. Sorry, this is more of a Python question than a PyTorch question, but from the documentation it is not clear how to open multiple streams and concat their results.


You should synchronize all the streams after the for loop (torch.cuda.synchronize()). Because the cat runs on the default stream, the other streams may not have finished by the time it runs.

I added torch.cuda.synchronize() as you said (that could have been the problem some of the time), but the assertion still fails on some random batches. I suspect that idx is getting mixed up between the for loop and the with block, and that makes my concat happen in the wrong order (since I’m using idx to sort the intermediate results):

def forward(self, emb):
    # old way
    tmp = [cnn(emb).squeeze() for cnn in self.cnns]
    seq_representation = torch.cat(tmp, dim=1)

    # new way
    stream_tmp = []
    streams = [(idx, torch.cuda.Stream()) for idx, _ in enumerate(self.cnns)]
    for idx, s in streams:
        with torch.cuda.stream(s):
            cnn = self.cnns[idx]  # <--- how to ensure idx is in sync with the idx in the for loop?
            stream_tmp.append((idx, cnn(emb).squeeze()))
    torch.cuda.synchronize()  # added synchronize
    stream_tmp = [t for idx, t in sorted(stream_tmp)]
    seq_representation_stream = torch.cat(stream_tmp, dim=1)

    # comparing the two
    diff = abs(seq_representation_stream - seq_representation).sum().data[0]
    print(diff)
    assert diff == 0.0
    return seq_representation