How to do mini batch with dynamic computation graph

Hi all, I am new to frameworks with dynamic computation graphs. I searched everywhere but couldn't find a reference on how to implement mini-batching for RNNs, or even Tree-LSTMs, with variable-length inputs. So I guess my general question is how to do mini-batching with a dynamic computation graph. Thanks.


For RNNs, there's already a batched variable-length example (on the SNLI dataset); all you have to do there is sort the examples a little so that sentences of similar lengths are batched together with minimal padding.
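For the RNN case, here is a minimal sketch of the sort-and-pad idea using `pad_sequence`, `pack_padded_sequence` and `pad_packed_sequence` from `torch.nn.utils.rnn` in current PyTorch. This is not the SNLI example's own code; `batch_by_length` and the toy data are hypothetical. Sorting keeps similar lengths together, and packing lets the LSTM skip the padded steps:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


def batch_by_length(sequences, batch_size):
    """Yield (original indices, padded batch, lengths) for similar-length groups."""
    # Sort by length, longest first, so each batch needs little padding
    # and lengths are already in the descending order packing expects.
    order = sorted(range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        batch = [sequences[i] for i in idx]
        lengths = torch.tensor([len(s) for s in batch])
        # pad_sequence -> (max_len, batch, features), padded with zeros
        yield idx, pad_sequence(batch), lengths


torch.manual_seed(0)
rnn = nn.LSTM(input_size=4, hidden_size=8)
# Toy data: 10 sequences of random length 1..9, each step a 4-dim vector
seqs = [torch.randn(int(n), 4) for n in torch.randint(1, 10, (10,))]

results = []
for idx, padded, lengths in batch_by_length(seqs, batch_size=4):
    # Packing tells the LSTM each sequence's true length, so padded steps are skipped
    packed = pack_padded_sequence(padded, lengths)
    out_packed, (h_n, c_n) = rnn(packed)
    out, out_lengths = pad_packed_sequence(out_packed)
    # h_n[-1]: last layer's hidden state at each sequence's final real step
    results.append((idx, lengths, out, h_n[-1]))
```

Keeping the original indices (`idx`) lets you put predictions back in the original order, and `h_n[-1]` is each sentence's state at its own last real token, so padding does not leak into the sentence representation.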

TreeRNNs are harder. I'll add an example soon that does this, but the general idea for TreeRNNs is that batching is up to you as the user, and you should split and concatenate when you need to. So if you use a binary tree structure, you can represent it as a shift-reduce parser (see the SPINN paper from Bowman et al.), which means you can process multiple trees in parallel by doing preprocessing like this:

 tree1: ((ab)c)
 tree2: (d(ef))
preprocessed input:
            1         2         3         4         5
 tree1:     SHIFT a   SHIFT b   REDUCE    SHIFT c   REDUCE
 tree2:     SHIFT d   SHIFT e   SHIFT f   REDUCE    REDUCE

and then use advanced indexing to copy all the SHIFTed tokens at each timestep in parallel, while concatenating the stack representations so that REDUCE runs as one batched operation.
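A rough sketch of that batched SHIFT/REDUCE loop for the two trees above. It assumes every tree in the batch has the same number of leaves (so the transition sequences line up) and uses a plain `nn.Linear` + `tanh` as the REDUCE composition instead of SPINN's Tree-LSTM composition and tracking LSTM. `BatchedShiftReduce` is a hypothetical name, and the per-example stacks are ordinary Python lists:

```python
import torch
import torch.nn as nn

SHIFT, REDUCE = 0, 1


class BatchedShiftReduce(nn.Module):
    """Hypothetical minimal SPINN-style encoder (no tracking LSTM).

    Assumes all trees in the batch have the same number of leaves, so every
    transition sequence has the same length (2 * n_leaves - 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.compose = nn.Linear(2 * dim, dim)  # REDUCE: (left, right) -> parent

    def forward(self, tokens, transitions):
        # tokens: (batch, n_leaves, dim); transitions: (batch, steps) of SHIFT/REDUCE
        batch = tokens.size(0)
        stacks = [[] for _ in range(batch)]
        buf_ptr = torch.zeros(batch, dtype=torch.long)  # next leaf to SHIFT, per tree
        for t in range(transitions.size(1)):
            ops = transitions[:, t]
            shift_ids = (ops == SHIFT).nonzero(as_tuple=True)[0]
            reduce_ids = (ops == REDUCE).nonzero(as_tuple=True)[0]
            if len(shift_ids) > 0:
                # Advanced indexing: fetch every SHIFTed token in one call
                shifted = tokens[shift_ids, buf_ptr[shift_ids]]
                buf_ptr[shift_ids] += 1
                for row, i in enumerate(shift_ids.tolist()):
                    stacks[i].append(shifted[row])
            if len(reduce_ids) > 0:
                # Pop right child first, then left child, from each reducing stack
                rights = torch.stack([stacks[i].pop() for i in reduce_ids.tolist()])
                lefts = torch.stack([stacks[i].pop() for i in reduce_ids.tolist()])
                # One batched composition for all trees that REDUCE at step t
                parents = torch.tanh(self.compose(torch.cat([lefts, rights], dim=1)))
                for row, i in enumerate(reduce_ids.tolist()):
                    stacks[i].append(parents[row])
        # After the final REDUCE, each stack holds only its root vector
        return torch.stack([s[-1] for s in stacks])


torch.manual_seed(0)
dim = 6
tokens = torch.randn(2, 3, dim)  # tree1 leaves a, b, c; tree2 leaves d, e, f
transitions = torch.tensor([[SHIFT, SHIFT, REDUCE, SHIFT, REDUCE],   # ((ab)c)
                            [SHIFT, SHIFT, SHIFT, REDUCE, REDUCE]])  # (d(ef))
model = BatchedShiftReduce(dim)
roots = model(tokens, transitions)  # (2, dim): one vector per tree
```

Only the stack bookkeeping is per example; gathering the SHIFTed tokens and composing the REDUCE pairs each run as one tensor operation per timestep. Trees of different sizes would need their transition sequences padded with some no-op, which this sketch does not handle.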

Sorry if this is confusing, I promise an example will be up soon.

I would add that PyTorch is impressively fast on TreeRNNs even without batching.


Thank you for the pointer. So it's actually up to the user to design the batching mechanism. Thank you all for building PyTorch with such amazing flexibility and great tutorials.

Just curious, is there any plan to release a technical report comparing PyTorch's performance with other frameworks that support dynamic computation graphs?

For now I think we just have to say that we're quite fast :) We'd rather have someone independently benchmark the frameworks, or do a collaboration where the maintainers of each framework implement the same script. Otherwise the benchmarks can end up a bit biased, because we don't know other libraries nearly as well as we know PyTorch.

That makes sense. Thanks for the great work.

In this particular SNLI example, examples/snli/ imports the torchtext module:

    from torchtext import data
    from torchtext import datasets

Does that module exist? I can't find it.

@rituk see this repo.


Using BucketIterator, which produces mini-batches with minimal padding, works fine for SNLI since it's a sentence classification task. However, for sequence tagging tasks, I don't think padded inputs (even with the number of pads minimized) are a good idea without gradient masking, since the loss would otherwise include the targets at the padded positions and backpropagate gradients from them.
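One way to do that gradient masking for tagging, sketched under the assumption that you have per-token logits of shape (batch, max_len, n_tags) plus the true `lengths` of each sequence (`masked_tagging_loss` is a hypothetical helper): build a boolean mask from the lengths and average the per-token loss over real tokens only.

```python
import torch
import torch.nn.functional as F


def masked_tagging_loss(logits, tags, lengths):
    """Mean cross-entropy over real (non-padded) tokens only.

    logits: (batch, max_len, n_tags), tags: (batch, max_len), lengths: (batch,)
    """
    max_len = logits.size(1)
    # mask[b, t] is True where t < lengths[b], i.e. a real token
    mask = torch.arange(max_len, device=logits.device).unsqueeze(0) < lengths.unsqueeze(1)
    # cross_entropy wants classes in dim 1; pad tags are replaced by a valid index (0)
    per_token = F.cross_entropy(logits.transpose(1, 2), tags.masked_fill(~mask, 0),
                                reduction="none")
    # Zero out padded positions, then average over the real tokens
    return (per_token * mask).sum() / mask.sum()


torch.manual_seed(0)
logits = torch.randn(2, 4, 3, requires_grad=True)
lengths = torch.tensor([4, 2])
tags = torch.tensor([[0, 1, 2, 1],
                     [2, 0, 0, 0]])  # last two tags of the second row are padding
loss = masked_tagging_loss(logits, tags, lengths)
loss.backward()  # padded positions get exactly zero gradient
```

If the padded targets are set to -100 instead, `F.cross_entropy`'s default `ignore_index=-100` gives the same mean without an explicit mask.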


Yes, this kind of padding is a stopgap until PyTorch has full masked RNN support, which is on its way.

What if we are dealing with arbitrary trees rather than binary ones? This corresponds to the Child-Sum Tree-LSTM, where each node can have a different number of children. Is it still possible to batch with the shift-reduce strategy?

It’s possible but a lot harder.

I hear that TensorFlow Fold is able to batch trees of arbitrary shapes. Are there similar implementations in PyTorch? Why is nobody trying to build such a tool?
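One way to batch arbitrary trees, similar in spirit to Fold's dynamic batching, is to group nodes by height (leaves first) and run one batched cell update per height level across all trees. Below is a minimal sketch for a Child-Sum Tree-LSTM (the equations from Tai et al., 2015), assuming each batch of trees is flattened into a node-feature matrix `x` plus a `parent` index array with -1 marking roots. `ChildSumTreeLSTM` and `node_heights` are hypothetical names, not an existing PyTorch API:

```python
import torch
import torch.nn as nn


def node_heights(parent):
    """Height of each node: 0 for leaves, 1 + max child height otherwise."""
    parent = parent.tolist()
    height = [0] * len(parent)
    changed = True
    while changed:  # relax parent heights until nothing changes
        changed = False
        for child, p in enumerate(parent):
            if p >= 0 and height[p] < height[child] + 1:
                height[p] = height[child] + 1
                changed = True
    return torch.tensor(height)


class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM batched by node height (hypothetical sketch)."""

    def __init__(self, in_dim, hid):
        super().__init__()
        self.hid = hid
        self.W_iou = nn.Linear(in_dim, 3 * hid)
        self.U_iou = nn.Linear(hid, 3 * hid, bias=False)
        self.W_f = nn.Linear(in_dim, hid)
        self.U_f = nn.Linear(hid, hid, bias=False)

    def forward(self, x, parent):
        # x: (n_nodes, in_dim) for all nodes of all trees; parent: (n_nodes,), -1 = root
        n = x.size(0)
        height = node_heights(parent)
        h = x.new_zeros(n, self.hid)
        c = x.new_zeros(n, self.hid)
        for level in range(int(height.max()) + 1):
            nodes = (height == level).nonzero(as_tuple=True)[0]
            # Child edges whose parent sits at this level (none for leaves)
            is_child = (parent >= 0) & (height[parent.clamp(min=0)] == level)
            child = is_child.nonzero(as_tuple=True)[0]
            par = parent[child]
            # Sum each node's children's h (index_add handles any number of children)
            h_sum = x.new_zeros(n, self.hid).index_add(0, par, h[child])[nodes]
            i, o, u = (self.W_iou(x[nodes]) + self.U_iou(h_sum)).chunk(3, dim=1)
            i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
            # One forget gate per child edge, gated cell states summed into the parent
            f = torch.sigmoid(self.W_f(x[par]) + self.U_f(h[child]))
            fc = x.new_zeros(n, self.hid).index_add(0, par, f * c[child])[nodes]
            new_c = i * u + fc
            new_h = o * torch.tanh(new_c)
            # Out-of-place writes keep autograd happy
            c = c.index_copy(0, nodes, new_c)
            h = h.index_copy(0, nodes, new_h)
        return h, c


torch.manual_seed(0)
# Two trees in one batch: node 0 has three leaf children (1, 2, 3);
# node 4 has children 5 and 6, and node 6 has one child, 7.
parent = torch.tensor([-1, 0, 0, 0, -1, 4, 4, 6])
x = torch.randn(8, 5)
model = ChildSumTreeLSTM(in_dim=5, hid=4)
h, c = model(x, parent)
roots = h[parent == -1]  # (2, 4): one vector per tree
```

The loop runs once per height level rather than once per node, so trees with different shapes and arities still share each batched cell update.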

For what it's worth, I was able to install torchtext using:

pip install git+

(in a virtualenv; otherwise try with --user)

What's an example for a feedforward NN or CNN? When I try to index my torch tensors of data, it says I can't (or shouldn't) use numpy to index them. As in:

import numpy as np
import torch

def get_batch(X, Y, M):
    N = len(Y)
    # Sample M distinct row indices with numpy
    batch_indices = np.random.choice(N, size=M, replace=False)
    # Indexing the torch tensors with this numpy array is where it complains
    batch_xs = X[batch_indices, :]
    batch_ys = Y[batch_indices]
    return batch_xs, batch_ys

where X and Y are torch tensors (or variables).
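For a feedforward net or CNN there is no variable length to handle, so sampling random row indices is enough. A sketch that stays in torch throughout, assuming a reasonably recent PyTorch where tensors can be indexed with a `LongTensor` (`get_batch_torch` is a hypothetical variant of `get_batch` above):

```python
import torch


def get_batch_torch(X, Y, M):
    """Sample M rows of X and Y without replacement, without touching numpy."""
    # randperm returns a LongTensor of shuffled indices on X's device
    batch_indices = torch.randperm(len(Y), device=X.device)[:M]
    batch_xs = X[batch_indices]  # same rows as torch.index_select(X, 0, batch_indices)
    batch_ys = Y[batch_indices]
    return batch_xs, batch_ys


X = torch.randn(100, 3)
Y = torch.arange(100.)  # toy targets equal to the row index, to make checks easy
xs, ys = get_batch_torch(X, Y, 16)
```

`X[idx]` and `torch.index_select(X, 0, idx)` return the same rows; `index_select` just needs `idx` to be a 1-D `LongTensor` rather than a numpy array.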

I think my code runs now, but it seems there has to be a better way than doing:

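import numpy as np
import torch
from torch.autograd import Variable

def get_batch2(X, Y, M, dtype):
    # Convert to numpy so numpy indices can be used (reconstructed line)
    X, Y = X.data.numpy(), Y.data.numpy()
    N = len(Y)
    batch_indices = np.random.choice(N, size=M, replace=False)
    # Convert the sampled rows back to torch tensors of the requested type
    batch_xs = torch.FloatTensor(X[batch_indices, :]).type(dtype)
    batch_ys = torch.FloatTensor(Y[batch_indices]).type(dtype)
    return Variable(batch_xs, requires_grad=False), Variable(batch_ys, requires_grad=False)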

I tried a couple of torch methods like gather and index_select but with no luck. Some of the things I tried:

    #valid_indices = torch.arange(0,N).numpy()
    #valid_indices = np.array( range(N) )
    #batch_indices = np.random.choice(valid_indices,size=M,replace=False)
    #indices = torch.LongTensor(batch_indices)
    #batch_xs, batch_ys = torch.index_select(X_mdl, 0, indices), torch.index_select(y, 0, indices)

I wonder if this back-and-forth between numpy and torch is actually slowing my code down. I hope not.
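Another option that avoids the numpy round-trip entirely is to let `torch.utils.data` do the shuffling and slicing. A minimal sketch with toy data (the shapes and `batch_size=16` are arbitrary assumptions):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
X = torch.randn(100, 3)
Y = torch.randn(100)

# TensorDataset pairs row i of X with element i of Y
loader = DataLoader(TensorDataset(X, Y), batch_size=16, shuffle=True)

seen = 0
for batch_xs, batch_ys in loader:
    # One pass over the loader is one epoch, in a fresh random order each time
    seen += batch_xs.size(0)
```

Unlike calling `get_batch` repeatedly, which samples each batch independently, a `DataLoader` with `shuffle=True` visits every example exactly once per epoch; with 100 examples and `batch_size=16`, the final batch has 4 rows.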

(As a side note, I also posted my question on SO. Are questions with this level of detail welcome on the PyTorch forum, or should one stick with SO?)