How to create batches of a list of varying dimension tensors?

I am trying to create batches for my training. My inputs are tensors with varying dimensions.

Let’s say I have a list of tensors for the source (input) and the target (output).
Tensors in the source have varying dimensions, for example: torch.Size([1, 100]), torch.Size([2, 100]), torch.Size([3, 100]), torch.Size([4, 100]),
but tensors in the target are all torch.Size([1, 100]).

I want batches in which all tensors have the same dimensions, e.g. all the [1, 100] tensors batched together, all the [2, 100] tensors batched together, and so on. I also want each target tensor to stay paired with its corresponding source tensor.

Can anyone help me or point me in the right direction? Many thanks.

Note: dim(0), which is the dimension that varies here, can be thought of as the sequence length in ordinary language modeling or seq2seq.


Hi,
Usually, with different sequence lengths, you can pad all inputs to the same length. After padding, if you are using a torch.nn RNN block such as LSTM() or GRU(), you can use pack_padded_sequence to feed in the padded input.

Otherwise you could create batches according to the length of the 0th dimension, like you said, but that might be inefficient. And using torch.utils.data.Dataset and torch.utils.data.DataLoader won’t be possible.
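A minimal sketch of that second option, grouping pairs by the size of dim 0 so that no padding is needed at all. `bucket_by_length` is a hypothetical helper, not a PyTorch API, and it works on plain Python lists rather than through a DataLoader:

```python
from collections import defaultdict

import torch


def bucket_by_length(sources, targets, batch_size):
    """Hypothetical helper: yield same-shape batches, keeping pairs aligned.

    sources: list of tensors shaped [L_i, D] (L_i varies)
    targets: list of tensors shaped [1, D]
    Yields (src_batch [B, L, D], tgt_batch [B, 1, D]) with no padding.
    """
    buckets = defaultdict(list)
    for idx, src in enumerate(sources):
        buckets[src.size(0)].append(idx)  # bucket key = size of dim 0

    for _, indices in sorted(buckets.items()):
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            # the same indices select sources and targets, so pairs stay matched
            src_batch = torch.stack([sources[i] for i in chunk])
            tgt_batch = torch.stack([targets[i] for i in chunk])
            yield src_batch, tgt_batch


sources = [torch.randn(n, 100) for n in (1, 2, 3, 1, 2, 4)]
targets = [torch.randn(1, 100) for _ in sources]
for src, tgt in bucket_by_length(sources, targets, batch_size=2):
    print(src.shape, tgt.shape)
```

Rare lengths end up in small batches (here the lengths 3 and 4 each get a batch of one), which is the inefficiency mentioned above.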

Thanks @Prerna_Dhareshwar.
But my only concern here is that these tensors are actually vector representations (100-dim vectors). They are not real sequences.
In that case, for example, if I want to pad the first input, which has dimension torch.Size([1, 100]), would the padded version be torch.Size([5, 100]), with all elements [1:, :] set to the 'padding' index?

Wouldn't that cause problems during training?

Why do you think it would be a problem during training?

I'm not exactly sure, but maybe because usually it is dim(1) that gets padded to make tensors equal size,
whereas here dim(1) of my tensors is actually the model dimension.
But I think you are right; I am going to try that and see how it works.
Just to double-check:
when using padding, all the batches in my case would be of size [5, 100], right?
And if we want to make it efficient by batching similar-size tensors together, we have to write our own custom iterator, and then torch.utils.data.Dataset and torch.utils.data.DataLoader are not usable?

Hi,

It shouldn’t be an issue even if you’re padding sequences of size 1.
Yes, after padding, all your sequences will have same length. Make sure you read the documentation for pack_padded_sequence to understand how exactly it works.

Yes, you are correct, since DataLoader() creates the batches and it will be hard to control which training examples go into a batch.

I think I know what the problem is:
since these tensors are actually real float values, if I pad them with '0', that '0' might have a meaning in the vector space.
It's not like the case where everything is a token index, there is a padding index, and the embedding layer maps it into the vector space.
These tensors are already in the vector space, and I am not sure what padding value I should use.

It's okay to just pad with 0s, because when you call pack_padded_sequence, besides the padded input you also provide the actual length of each sequence. So whatever you pad with beyond that length is not considered in the output of the RNN. For example, if your input has size x.shape = (1, 1, 100), which is (sequence, batch, input dimension), and you pad it so that x_padded.shape = (5, 1, 100), then you create the packed sequence as follows:

from torch.nn.utils.rnn import pack_padded_sequence

x_packed = pack_padded_sequence(x_padded, [1])  # [1] = true length of the single sequence

The second argument [1] indicates that the length of the sequence x is only 1, so the values after the first step are not considered. Does this make sense?
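To check the claim that the padding values are ignored, here is a minimal sketch, assuming a single-layer nn.LSTM with hidden size 16 (both chosen only for illustration). The final hidden state computed from the packed, padded input should match the one from the unpadded input, even when the padding value is not zero:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=100, hidden_size=16)  # seq-first layout by default

x = torch.randn(1, 1, 100)                 # (sequence=1, batch=1, features=100)
x_padded = torch.full((5, 1, 100), 123.0)  # arbitrary padding value on purpose
x_padded[:1] = x                           # the real step goes first

_, (h_plain, _) = lstm(x)                                      # no padding at all
_, (h_packed, _) = lstm(pack_padded_sequence(x_padded, [1]))   # padded, then packed

print(torch.allclose(h_plain, h_packed))
```

Because the packed sequence only contains the first step, the 123.0 padding never reaches the LSTM, so the choice of padding value does not matter here.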


Oh I see, that totally makes sense now. Thanks a lot.

I'm sorry, I just got confused about the steps.
Should I pad the sequences, then use DataLoader() to create batches, and then use pack_padded_sequence to feed in the inputs?

And can I use pack_padded_sequence with any deep learning network, like a Transformer?

Yes, you’re right, you pad and use pack_padded_sequence to feed in inputs.
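Putting the steps together, a minimal sketch under a few assumptions: the dataset is a plain list of (source, target) pairs, `pad_collate` and `Encoder` are hypothetical names, and the model simply maps the LSTM's last hidden state to a [1, 100] prediction. Padding happens in collate_fn, and packing happens inside forward():

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Hypothetical collate_fn; batch is a list of (source [L_i, D], target [1, D])."""
    sources, targets = zip(*batch)
    lengths = torch.tensor([s.size(0) for s in sources])
    padded = pad_sequence(sources, batch_first=True)  # [B, L_max, D], zero-padded
    return padded, lengths, torch.stack(targets)      # targets: [B, 1, D]


class Encoder(nn.Module):
    """Hypothetical model: LSTM over the packed batch, linear head on h_n."""

    def __init__(self, dim=100, hidden=32):
        super().__init__()
        self.lstm = nn.LSTM(dim, hidden, batch_first=True)
        self.out = nn.Linear(hidden, dim)

    def forward(self, padded, lengths):
        packed = pack_padded_sequence(
            padded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)         # h_n: [num_layers, B, hidden]
        return self.out(h_n[-1]).unsqueeze(1)   # [B, 1, dim], same shape as targets


# Toy data standing in for the source/target lists from the question.
data = [(torch.randn(n, 100), torch.randn(1, 100)) for n in (1, 2, 3, 4, 2)]
loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=pad_collate)

model = Encoder()
for padded, lengths, targets in loader:
    pred = model(padded, lengths)
    loss = nn.functional.mse_loss(pred, targets)
    print(padded.shape, lengths.tolist(), loss.item())
```

enforce_sorted=False lets pack_padded_sequence accept a batch that is not sorted by length, and nn.LSTM returns h_n in the original batch order, so predictions stay aligned with targets.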

You can use it wherever the docs specify that a module accepts inputs in that format. I’m not sure what all that includes.


It seems like a PackedSequence object can only be used by nn.RNN, nn.LSTM, and nn.GRU.
How can we create batches to use in our own functions/classes, which in my case are mostly attention (dot product) layers?

I would appreciate any help.

I guess that is because only in RNNs does the fact that the data is a sequence matter - i.e., memory is retained from earlier parts of the sequence even at the end of the sequence. Do you need to retain the data in this ‘sequence’ format for your non-RNN application? What exactly are you trying to do?

I want to use these as input to a Transformer network; essentially, I want to use these tensors in the attention layers.
Hypothetically, imagine that I already have the word embeddings, so I don't need the Transformer's embedding layer,
but I do want the position embedding, since the position of these tensors matters to me. That is, for an input X of size [2, 100], the notion of sequence is important.
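One way to do this without PackedSequence, sketched under assumptions: PyTorch's built-in nn.TransformerEncoder with batch_first=True, a learned position embedding (nn.Embedding over positions, with a hypothetical max_len of 512), and a boolean src_key_padding_mask built from the lengths, where True marks padding that attention should ignore. The class name and hyperparameters are illustrative only:

```python
import torch
from torch import nn


class PaddedVectorEncoder(nn.Module):
    """Hypothetical encoder for pre-computed 100-d vectors (no token embedding).

    x: [B, L_max, d_model], zero-padded; lengths: [B] true lengths.
    """

    def __init__(self, d_model=100, nhead=4, max_len=512, num_layers=2):
        super().__init__()
        self.pos = nn.Embedding(max_len, d_model)  # learned position embedding
        layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=256, dropout=0.0, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x, lengths):
        positions = torch.arange(x.size(1), device=x.device)
        x = x + self.pos(positions)  # [L_max, d_model] broadcasts over the batch
        # True = padding position that attention must ignore
        pad_mask = positions[None, :] >= lengths[:, None]
        return self.encoder(x, src_key_padding_mask=pad_mask)


torch.manual_seed(0)
model = PaddedVectorEncoder().eval()
x = torch.zeros(2, 5, 100)
x[0, :2] = torch.randn(2, 100)  # a [2, 100] input padded to length 5
x[1] = torch.randn(5, 100)      # a full-length [5, 100] input
lengths = torch.tensor([2, 5])
out = model(x, lengths)         # [2, 5, 100]; out[0, 2:] is meaningless
```

Since padded keys are masked out, the outputs at real positions do not depend on what the padding contains; you still have to ignore the outputs at padded positions yourself.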

I think you could implement something similar to what PackedSequence does, but for your case. In other words, you can manually pad your “sequences” (they really aren't, as you mentioned) and pass their “lengths” along with them through the DataLoader's collate_fn function, so you know how to get the essential values back (e.g., you have a tensor of size (5, 100), but you only need [0:2, :] of it). In practice, you could do something like this:

import torch

def collate_fn(data):
    """
       data: a list of tuples (example, label, length),
             where 'example' is a tensor of shape [length, n_ftrs]
             and label/length are scalars
    """
    _, labels, lengths = zip(*data)
    max_len = max(lengths)
    n_ftrs = data[0][0].size(1)
    # zero-initialised batch: rows beyond each example's length stay 0 (padding)
    features = torch.zeros((len(data), max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i, (example, _, _) in enumerate(data):
        # copy the real rows; the remaining rows keep their zero padding
        features[i, :example.size(0)] = example

    return features.float(), labels.long(), lengths.long()

The function above is passed to the collate_fn parameter of the DataLoader, as in this example:

DataLoader(toy_dataset, collate_fn=collate_fn, batch_size=5)

With this collate_fn function, you will always get a tensor in which all examples have the same size. So, when you feed this data to your forward() function, you need to use the lengths to get the original data back and avoid using those meaningless zeros in your computation.
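For a custom dot-product attention, one way to "use the lengths" is to turn them into a boolean mask and set the scores of padded key positions to -inf before the softmax, so they receive exactly zero weight. A minimal sketch; `length_mask` and `masked_dot_product_attention` are hypothetical helpers, and every length is assumed to be at least 1 (an all-masked row would produce NaNs):

```python
import math

import torch


def length_mask(lengths, max_len):
    """[B, max_len] bool mask, True where the position holds real data."""
    steps = torch.arange(max_len, device=lengths.device)
    return steps[None, :] < lengths[:, None]


def masked_dot_product_attention(q, k, v, lengths):
    """q, k, v: [B, L, D] zero-padded; lengths: [B] true lengths (all >= 1).

    Padded key positions get -inf scores, so softmax gives them zero weight.
    Outputs at padded query positions are computed but meaningless.
    """
    scores = q @ k.transpose(1, 2) / math.sqrt(q.size(-1))  # [B, L, L]
    key_mask = length_mask(lengths, k.size(1))                # [B, L]
    scores = scores.masked_fill(~key_mask[:, None, :], float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
x = torch.zeros(2, 5, 100)
x[0, :2] = torch.randn(2, 100)  # length-2 example padded to 5
x[1] = torch.randn(5, 100)      # full-length example
lengths = torch.tensor([2, 5])
out, weights = masked_dot_product_attention(x, x, x, lengths)
```

This keeps the whole batch in one tensor; the mask, not the padding value, decides what the attention can see.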


What is the use of returning lengths?

Also, how would you use the built-in torch.nn.utils.rnn.pad_sequence in your example?

I did some examples and this is what I got:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-length sequences.

    note: it converts things ToTensor manually here since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## pad (default layout is [max_len, batch, *], i.e. batch_first=False)
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask (element-wise; genuine zeros in the data are masked too)
    mask = (batch != 0).to(device)
    return batch, lengths, mask

I think it should work for sequences of tensors with arbitrary shapes, as long as every dimension after the first matches (pad_sequence only pads dim 0).


What is the use of returning lengths?

I think it is easier to understand with an example. Suppose we need to classify whether a video shows a teaching class or a person practicing a sport. A video is a set of frames, each of which is an image, so one approach is to pass the frames through an LSTM model. Suppose every frame has torch.Size([C, H, W]), where C is the number of RGB channels, H is the height and W is the width of the image. We also have a set of videos, and each video might have a different length, and therefore a different total number of frames. For example, video1 has 5350 frames while video2 has 3323 frames. You can represent video1 and video2 with tensors of shape torch.Size([5350, C, H, W]) and torch.Size([3323, C, H, W]) respectively. The two tensors have different sizes in the first dimension, which prevents us from stacking them into a single tensor.

To make that possible, we can save a tensor lengths = [5350, 3323] and then pad all video tensors with zeros so they have equal length, i.e., the length of the longest one, 5350, giving two tensors of shape torch.Size([5350, C, H, W]). We can then stack both tensors (e.g. with torch.stack) into a single tensor of shape torch.Size([2, 5350, C, H, W]), where 2 is the batch_size. But stacking loses the information about each video's true length: for video2, every entry of video2_tensor[3323:, ...] is just 0. To remedy this, we use the lengths vector to get the original sequence back instead of a bunch of zeros.
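A scaled-down sketch of the video example, assuming tiny 3x8x8 frames and 6/4 frames instead of 5350/3323 so it runs instantly. pad_sequence pads dim 0 and stacks in one call, and lengths recovers each original video:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

C, H, W = 3, 8, 8                # tiny frames instead of real video resolution
video1 = torch.rand(6, C, H, W)  # stands in for the 5350-frame video
video2 = torch.rand(4, C, H, W)  # stands in for the 3323-frame video

lengths = torch.tensor([v.size(0) for v in (video1, video2)])  # [6, 4]
batch = pad_sequence([video1, video2], batch_first=True)       # zero-pads dim 0
print(batch.shape)  # torch.Size([2, 6, 3, 8, 8])

# batch[1, 4:] holds only zeros: padding, not real frames.
# Slicing each row with its length gives the original videos back.
recovered = [batch[i, :n] for i, n in enumerate(lengths.tolist())]
```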

Also, how would you use the built-in torch.nn.utils.rnn.pad_sequence in your example?

Yes! You could use it, and your code seems fine to me. But why do you need the mask = (batch != 0).to(device) line?
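A small sketch of why that mask line can be risky: (batch != 0) is value-based, so a real step that happens to be all zeros is indistinguishable from padding, whereas a mask built from lengths is not. The toy data here is made up for illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [
    torch.tensor([[1.0, 0.0], [0.0, 0.0]]),  # 2 real steps; the 2nd is all zeros
    torch.tensor([[3.0, 4.0]]),              # 1 real step
]
batch = pad_sequence(seqs, batch_first=True)  # [2, 2, 2]
lengths = torch.tensor([s.size(0) for s in seqs])

value_mask = batch != 0  # element-wise, based on values
length_mask = torch.arange(batch.size(1))[None, :] < lengths[:, None]  # [B, T]

print(value_mask.any(-1).tolist())  # [[True, False], [True, False]]
print(length_mask.tolist())         # [[True, True], [True, False]]
```

The value-based mask drops the genuine all-zero step of the first sequence; the length-based mask keeps it.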


It seems there are many related posts, so I collected a list to have them all accessible in one place:

bucketing:


Stack Overflow question:


crossposted: https://www.quora.com/unanswered/How-does-Pytorch-Dataloader-handle-variable-size-data


Just to add something: it is a lot more efficient to pad batch by batch (via collate_fn) rather than padding the whole dataset, so I strongly encourage doing that. I have an example for sequence tagging at https://github.com/marctorrellas/conl2002. Feel free to ask if you have any questions.
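Combining per-batch padding with the bucketing idea listed above, one possible sketch uses DataLoader's batch_sampler argument, which accepts any iterable of index lists. `sorted_batches` is a hypothetical helper; here the batch order is shuffled only once, so you would rebuild it each epoch for fresh shuffling:

```python
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def sorted_batches(lengths, batch_size, shuffle=True):
    """Hypothetical helper: batches of indices with similar lengths.

    Sorting by length keeps per-batch padding small; shuffling the batch
    order avoids always feeding short sequences first.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        random.shuffle(batches)
    return batches


def pad_collate(batch):
    lengths = torch.tensor([x.size(0) for x in batch])
    return pad_sequence(batch, batch_first=True), lengths  # padded to this batch's max


data = [torch.randn(n, 100) for n in (1, 7, 2, 8, 3, 9, 1, 2)]
loader = DataLoader(
    data,
    batch_sampler=sorted_batches([x.size(0) for x in data], batch_size=2),
    collate_fn=pad_collate,
)
for padded, lengths in loader:
    print(padded.shape, lengths.tolist())
```

Each batch is padded only to its own longest sequence (1, 2, 7 and 9 steps here) instead of padding the whole dataset to 9.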


Thanks for this! I just have a follow-up question: once I pass the batch to the forward function, the only solution I can think of is to iterate through each individual tensor to remove the padding (since each tensor has its own length). Is there a more elegant way of doing this? Iterating through each tensor in the batch would be very inefficient and time-consuming.
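One loop-free option, sketched under the assumption that lengths is a [B] integer tensor with every length at least 1: build a [B, T] validity mask with a broadcast comparison, then reduce or index with it. `masked_mean` and `last_valid_step` are hypothetical helpers; which reduction you need depends on the model:

```python
import torch


def masked_mean(padded, lengths):
    """Mean over the valid steps of each sequence, without a Python loop.

    padded: [B, T, D] with arbitrary padding values; lengths: [B] (all >= 1).
    """
    steps = torch.arange(padded.size(1), device=padded.device)
    valid = steps[None, :] < lengths[:, None]                       # [B, T] bool
    summed = padded.masked_fill(~valid[..., None], 0.0).sum(dim=1)  # [B, D]
    return summed / lengths[:, None].to(padded.dtype)


def last_valid_step(padded, lengths):
    """Pick padded[b, lengths[b] - 1] for every b with one indexing op."""
    batch_idx = torch.arange(padded.size(0), device=padded.device)
    return padded[batch_idx, lengths - 1]  # [B, D]


padded = torch.full((2, 3, 4), 99.0)       # 99.0 marks padding
padded[0, :2] = torch.ones(2, 4)
padded[1] = torch.arange(12.0).view(3, 4)
lengths = torch.tensor([2, 3])
print(masked_mean(padded, lengths).tolist())
# [[1.0, 1.0, 1.0, 1.0], [4.0, 5.0, 6.0, 7.0]]
print(last_valid_step(padded, lengths).tolist())
# [[1.0, 1.0, 1.0, 1.0], [8.0, 9.0, 10.0, 11.0]]
```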
