About variable-length input in the RNN scenario

Hi all, I've recently been trying to build an RNN model for an NLP task, and I found that the RNN layer interfaces provided by PyTorch (regardless of cell type, GRU or LSTM) don't support masking the inputs. Masking is widely used in the NLP domain when the inputs within a single batch have different lengths (since inputs are generally batches of natural language sentences), so I'm wondering: will this be a future feature in PyTorch, or do I have to find some other way to do the masking myself? Thanks.

3 Likes

Yeah, that's something we need to, and plan to, figure out quite soon, as it's an important feature. For now, you could pad the outputs of the network after the EOS token with special values that make the loss equal to 0. Hopefully we'll have a solution ready this week.
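Something along these lines should work as a starting point (just a rough sketch; PAD_IDX and the shapes are placeholders): compute the per-token loss, then zero it out wherever the target is padding.

import torch
import torch.nn.functional as F

PAD_IDX = 0  # hypothetical index used for padded target positions

def masked_nll_loss(log_probs, targets):
    # log_probs: (batch * seq_len, vocab), targets: (batch * seq_len,)
    per_token = F.nll_loss(log_probs, targets, reduction='none')
    mask = (targets != PAD_IDX).float()   # 0 at padded positions, 1 elsewhere
    return (per_token * mask).sum() / mask.sum()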

6 Likes

That's great! I really appreciate your efforts on it.

Padding variable-length input works reasonably well on the CPU (I haven't tried the GPU yet). Here are a few examples with "dynamic batching". Basically, for a batch that looks like this:

[[0, 0, 1, 1], [1, 1, 1, 1]]

The batch size at time steps 0 and 1 will be 1, and at time steps 2 and 3 it will be 2.
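Concretely, treating the batch above as a 0/1 mask, the per-step batch sizes are just the column sums (a tiny sketch):

import torch

mask = torch.tensor([[0, 0, 1, 1],
                     [1, 1, 1, 1]])
batch_sizes = mask.sum(dim=0)   # tensor([1, 1, 2, 2]): batch size at each time step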

I was surprised that dynamic batching was slower. That being said, there are some tricky indexing operations and concatenations that might have a nicer implementation.

What is dynamic batching? Just iterating over inputs one step at a time, and slicing the batch if some sequence ends?

Dynamic batching is the key advantage provided by TensorFlow Fold, which makes it possible to create a different computation graph for each sample inside a single mini-batch. @mrdrozdov tried to implement dynamic batching in PyTorch and succeeded. However, the dynamic batching version of the RNN is even slower than the padding version. After profiling his code I found that the hotspot that makes dynamic batching slow is torch.chunk. The result shows that this operation takes even more time than the RNN forward pass:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   108                                               @profile
   109                                               def forward(self, x, lengths):
   110       541         1331      2.5      0.0          batch_size = len(x)
   111     17853        20234      1.1      0.2          lengths = [len(s) for s in x]
   112
   113       541          514      1.0      0.0          outputs = [Variable(torch.zeros(1, self.model_dim).float(), volatile=not self.training)
   114     17853       231300     13.0      1.9                     for _ in range(batch_size)]
   115
   116     11522        14014      1.2      0.1          for t in range(max(lengths)):
   117     10981        19603      1.8      0.2              batch = []
   118     10981        15608      1.4      0.1              h = []
   119     10981        14756      1.3      0.1              idx = []
   120    362373       424946      1.2      3.5              for i, (s, l) in enumerate(zip(x, lengths)):
   121    351392       809330      2.3      6.7                  if l >= max(lengths) - t:
   122    267925       399322      1.5      3.3                      batch.append(s.pop())
   123    267925       307910      1.1      2.6                      h.append(outputs[i])
   124    267925       300516      1.1      2.5                      idx.append(i)
   125
   126     10981       316257     28.8      2.6              batch = np.concatenate(np.array(batch).reshape(-1, 1), 0)
   127     10981       161699     14.7      1.3              emb = Variable(torch.from_numpy(self.initial_embeddings.take(batch, 0)), volatile=not self.training)
   128     10981       522216     47.6      4.3              h = torch.cat(h, 0)
   129     10981      2529893    230.4     21.1              h_next = self.rnn(emb, h)
   130     10981      4748304    432.4     39.5              h_next = torch.chunk(h_next, len(idx))
   131
   132    278906       322694      1.2      2.7              for i, o in zip(idx, h_next):
   133    267925       474999      1.8      4.0                  outputs[i] = o
   134
   135       541        27823     51.4      0.2          outputs = torch.cat(outputs, 0)
   136       541       174478    322.5      1.5          h = F.relu(self.l0(F.dropout(outputs, 0.5, self.training)))
   137       541       152165    281.3      1.3          h = F.relu(self.l1(F.dropout(h, 0.5, self.training)))
   138       541        25429     47.0      0.2          y = F.log_softmax(h)
   139       541          585      1.1      0.0          return y

So, is there a more efficient alternative to torch.chunk? Is there any way to implement dynamic batching in PyTorch efficiently?

1 Like

For now, you have to use the padding approach in order to take advantage of the substantial speedup afforded by CUDNN’s accelerated RNN kernels. If you only need a unidirectional RNN, you can mask the resulting tensors and remove the effects of the padding completely. If you want variable-sequence-length support with a bidirectional RNN, or would like true dynamic batching that doesn’t even run computations for padding tokens, CUDNN actually supports this internally but PyTorch does not yet have a wrapper (expect one fairly soon).
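For the unidirectional case, the masking can look something like this (a rough sketch with made-up sizes): run the cuDNN-backed LSTM over the padded batch, then zero the outputs at padded positions.

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
x = torch.randn(2, 4, 32)            # right-padded batch: (batch, max_len, features)
lengths = torch.tensor([2, 4])       # true length of each sequence

out, _ = rnn(x)                      # (batch, max_len, hidden), padded positions included
mask = (torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)).float()  # (batch, max_len)
out = out * mask.unsqueeze(2)        # zero out everything past each sequence's end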

BTW, are there benchmarks on TF Fold? I can’t imagine their repeated concatenations+splits are all that much faster than they’d be in PyTorch.

4 Likes

There’s no faster alternative to chunk at the moment. But if it’s a bottleneck for some applications we can speed it up for sure. I’ll try to take a look at the code sometime.

1 Like

Additionally, I think the code could be greatly simplified and could avoid chunk entirely if it just sorted the sequences by length.
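Roughly what I have in mind (just a sketch with hypothetical sizes): if the sequences are sorted by length in descending order, the active sequences at step t are always the first n_active rows, so plain slicing replaces torch.chunk and the per-sample bookkeeping.

import torch
import torch.nn as nn

batch, max_len, feat, hidden = 3, 4, 8, 16
lengths = torch.tensor([4, 3, 1])          # sorted in descending order
emb = torch.randn(batch, max_len, feat)    # right-padded inputs
cell = nn.RNNCell(feat, hidden)
h = torch.zeros(batch, hidden)

for t in range(max_len):
    n_active = int((lengths > t).sum())    # active sequences form a contiguous prefix
    h_new = cell(emb[:n_active, t], h[:n_active])
    h = torch.cat([h_new, h[n_active:]], dim=0)   # finished sequences keep their last state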

Looking forward to it very much.

I'm also wondering whether this strategy is faster than the padding + masking solution.

For now, I will use the old-fashioned padding (maybe with masking) approach.

Padding + masking might have some advantage on the GPU, because you can use cuDNN RNN kernels that parallelize computation across multiple timesteps, and the more data you give them, the more efficient they’ll get.

@apaszke, I would definitely be open to suggestions in this direction! The data that I've worked with has batches that look closer to something like:

0010011
0000111

Where everything marked 1 at a timestep is involved in a batched RNN op. It's not clear to me how to get away without using torch.chunk.

Here’s another example:

001001001001111
000011100001111
000001110001111
000000011111111

When is such a format used? I assume that each line is an independent batch element and never interacts with the other ones, right? Why do you have these blanks in the data?

1 Like

As @jekbradbury mentioned, just padding and masking outside the RNN modules won't work correctly for bidirectional cases.
How about first following the cuDNN approach, and considering other approaches later?

We’re going to add cuDNN variable length bindings soon.

5 Likes

I see that you've pushed the variable-length RNN support to the main branch. Does it take care of the bidirectional case?

Yes. Feel free to check it out – none of the examples have been updated to use it yet, but we’ll update SNLI and OpenNMT soon.

4 Likes

Excellent work! Can't wait to look at your examples! :smile:

I don't understand how to use torch.nn.utils.rnn.PackedSequence. Is it any different from padding the sequence myself?

1 Like

Yes, it’s different. You need to build a padded sequence yourself, then pass it into nn.utils.rnn.pack_padded_sequence; the resulting object is a little confusing, but it’s the only format that nn.LSTM will accept and then process as if there were no padding anywhere. You could manually simulate the same result using a unidirectional nn.LSTM, but it would be impossible to completely replicate what this does with a bidirectional nn.LSTM.
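For example (a minimal sketch with made-up sizes; note that the lengths are sorted in descending order, which the packing code expects):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True, bidirectional=True)
x = torch.randn(3, 5, 32)              # right-padded batch: (batch, max_len, features)
lengths = torch.tensor([5, 3, 2])      # true lengths, sorted in descending order

packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)  # the LSTM never sees the padded positions
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # back to a padded tensor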

1 Like