How to do padding based on lengths?

I have a list of sequences that I padded to the same length (emb_len). I also have a separate tensor that I want to concatenate to every data point in the sequences.

Intuitively, it looks something like this:

a b c d e f g 0 0 0
u u u u u u u u u u

h i j k l 0 0 0 0 0
u u u u u u u u u u

but the correct result (I suppose) would be:

a b c d e f g 0 0 0
u u u u u u u 0 0 0

h i j k l 0 0 0 0 0
u u u u u 0 0 0 0 0

I tried something like:

torch.cat([seq_embed,
           torch.cat([second_embed.unsqueeze(1).expand(batch_size, emb_len, second_emb_len),
                      torch.zeros([batch_size, second_embed.size(1) - emb_len, second_emb_len], dtype=torch.long)], 1)],
          2)

However, this does not work because emb_len is a tensor of per-sequence lengths, something like torch.LongTensor([1,2,3,4,5]), so I get errors like "a Tensor with # elements cannot be converted to Scalar". Is there any way to solve this problem?


I think you are looking for torch.nn.utils.rnn.pad_sequence.
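A minimal sketch of pad_sequence on made-up data (the shapes and the names seqs and lengths are illustrative, not from the original question). Note that pad_sequence does not return the original lengths, so they have to be kept separately:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences of 4-dim embeddings (illustrative data).
seqs = [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)]

# batch_first=True gives (batch, max_len, 4); the default (False) gives (max_len, batch, 4).
padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
print(padded.shape)  # torch.Size([3, 5, 4])

# The original lengths are not returned, so keep them yourself.
lengths = torch.tensor([s.size(0) for s in seqs])
print(lengths)  # tensor([3, 5, 2])
```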

If you want to do this manually:

  • One greatly underappreciated (to my mind) feature of PyTorch is that you can allocate a tensor of zeros (of the right type) and then copy to slices without breaking the autograd link. This is what pad_sequence does (the source code is linked from the “headline” in the docs). The crucial bit is:
    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

If the tensors require grad, so will out_tensor and the gradients will flow back to the tensors in the list.

  • Another way to do this, which seems closer to your description, is to use a cat (or pad) in a list comprehension and push the results into a stack.
# setup
import torch
l = [torch.tensor([1,2,3]), torch.tensor([4,5]),torch.tensor([6,7,8,9])]
emb_len=4
# this is what you want:
lp = torch.stack([torch.cat([i, i.new_zeros(emb_len - i.size(0))], 0) for i in l],1)
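To illustrate the autograd point from the first option, here is a small sketch (tensor names and values are made up) showing that copying into slices of a preallocated zero tensor keeps the gradient link to the original sequences:

```python
import torch

# Two leaf tensors of different lengths that require gradients (illustrative data).
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)

# Preallocate the padded output, then copy each sequence into a slice.
out = a.new_zeros(2, 3)
out[0, :3] = a
out[1, :2] = b
print(out.requires_grad)  # True

# Gradients flow back to the original tensors through the slice copies.
(out * 2).sum().backward()
print(a.grad)  # tensor([2., 2., 2.])
print(b.grad)  # tensor([2., 2.])
```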

Best regards

Thomas


I wrote a simple one recently after reading nn.utils.rnn.pad_sequence. Hope this helps.

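def padding_tensor(sequences):
    """
    :param sequences: list of 1-D tensors of varying length
    :return: (out_tensor, mask), both of shape (len(sequences), max_len);
             mask is 1 at real positions and 0 at padding
    """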
    num = len(sequences)
    max_len = max([s.size(0) for s in sequences])
    out_dims = (num, max_len)
    out_tensor = sequences[0].data.new(*out_dims).fill_(0)
    mask = sequences[0].data.new(*out_dims).fill_(0)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        out_tensor[i, :length] = tensor
        mask[i, :length] = 1
    return out_tensor, mask
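The mask depends only on the lengths, so it can also be built without a Python loop by comparing a row of positions against each length. A sketch, assuming the lengths are available as a 1-D LongTensor; lengths_to_mask is a hypothetical helper name, not a PyTorch API:

```python
import torch

def lengths_to_mask(lengths, max_len=None):
    """Boolean (batch, max_len) mask, True at non-padded positions.

    Hypothetical helper for illustration; not part of torch.
    """
    if max_len is None:
        max_len = int(lengths.max())
    # Positions [0, 1, ..., max_len - 1] as a (1, max_len) row, compared with
    # lengths as a (batch, 1) column; broadcasting yields (batch, max_len).
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

mask = lengths_to_mask(torch.tensor([3, 1, 2]))
# rows: [True, True, True], [True, False, False], [True, True, False]
```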

Hi Tom,

I do not think there is a way to use pad_packed_seq if you want to concatenate two tensors “vertically”. In my example, the upper line (abcdefg000) cannot be combined directly with the lower line (u at every real position, 0 at every padding position). So what I have to do is manually pad the upper sequence (abcdefg => abcdefg000) and find some way to stack it with u.

The second solution seems to work. I will give it a shot today and let you know.

Thanks!

I think you can pad 2d (seq_len, 0/1) tensors using pad_sequence, but you would need to concatenate first. You could also write your own indexing variant (writing into rows 2i and 2i+1); I would expect that to be more efficient than many cats.
Another option might be to pad the data first and then derive the mask (padded_data > 0) from the jointly padded tensor.
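A sketch of the "concatenate first, then pad" idea, assuming unpadded sequences x_i of shape (len_i, D1) and one extra embedding u_i of shape (D2,) per sequence (all names and sizes are illustrative). It builds the mask from the lengths instead of padded_data > 0, since a value test misfires if real entries can be zero or negative:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative data: sequences x_i of shape (len_i, 4), one u_i of shape (3,) each.
xs = [torch.randn(3, 4), torch.randn(5, 4)]
us = [torch.randn(3), torch.randn(3)]

# Concatenate first: repeat u_i over the len_i real steps and append it along
# the feature dimension, giving (len_i, 4 + 3).
joined = [torch.cat([x, u.expand(x.size(0), -1)], dim=1) for x, u in zip(xs, us)]

# Then pad once; padded steps are zero in both the x part and the u part.
padded = pad_sequence(joined, batch_first=True)  # (2, 5, 7)

# Boolean (batch, max_len) mask computed from the lengths.
lengths = torch.tensor([x.size(0) for x in xs])
mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
```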

Best regards

Thomas

I just realized that the stack solution does not solve the problem either. Let me rephrase the task a little.

What I have here is a batch of:

padding_data_1: e.g., abc000, efgh00. Every sequence has the same length, padded with 0, where a, b, c, …, stand for embedding vectors.

data_1_len: e.g., [3,4]. The original lengths of data_1 before padding.

data_2: a single embedding for each data point, e.g., [i, j], where i, j stand for embedding vectors.

The target is then

[abc000]       [efgh00]
[iii000]  and  [jjjj00]

(Concatenated vertically with padding)

You can adapt the solution above that copies into a preallocated tensor (I must admit I still don’t fully understand whether you want one or two tensors as a result). I think you can do

out_tensor[i, :length, ...] = single_embedding_per_point[i, None]

and similar, relying on broadcasting: the singleton dimension created by the None index is expanded to match the :length slice on the left-hand side. Use single_embedding_per_point[i][None] if it’s a list of tensors rather than one large tensor.
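A runnable sketch of that broadcasting assignment, using the names from the rephrased question (padded_data_1, data_1_len, data_2) with made-up sizes; the variable second is a hypothetical intermediate, and the final cat is only needed if you want a single tensor:

```python
import torch

B, T, D1, D2 = 2, 6, 4, 3
padded_data_1 = torch.randn(B, T, D1)  # already padded sequences (illustrative)
data_1_len = [3, 4]                    # original lengths
data_2 = torch.randn(B, D2)            # one embedding per data point

# Preallocate zeros for the second part, then fill only the first `length` steps.
second = padded_data_1.new_zeros(B, T, D2)
for i, length in enumerate(data_1_len):
    # data_2[i, None] has shape (1, D2) and broadcasts over the `length` rows.
    second[i, :length, ...] = data_2[i, None]

# One tensor of shape (B, T, D1 + D2); keep `second` alone if you want two tensors.
out = torch.cat([padded_data_1, second], dim=2)
```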

Best regards

Thomas

Here is a working solution, though it does not handle some corner cases. Is it clear what I am doing? Do you have any suggestions for making the code more efficient? Thanks!

torch.cat([first_embed,
           torch.stack([
               torch.cat([
                   _emb.expand(_s_len, 10),
                   torch.zeros(PADDING_LEN - _s_len, 10, dtype=torch.float)
               ], 0)
               for _s_len, _emb in zip(original_len, second_embed)
           ], 0)
          ], 2)

Here first_embed is padding_data_1, original_len is data_1_len, and second_embed is data_2 from my previous reply.

import torch

FE_DIM = 100
first_embed = torch.randn(5,20,FE_DIM)
second_embed = torch.randn(5, 10)
original_len = [10, 8, 9, 7, 3]
PADDING_LEN = 20
def first():
    res = torch.cat([first_embed, 
               torch.stack([
                    torch.cat([
                          _emb.expand(_s_len, 10), 
                          torch.zeros(PADDING_LEN - _s_len, 10, dtype=torch.float)]
                    , 0) 
                    for _s_len, _emb in zip(original_len, second_embed)]
               , 0)
          ], 2)
    return res
def second():
    res2 = torch.zeros(5, 20, FE_DIM+10)
    res2[:,:,:FE_DIM] = first_embed
    for i, e in enumerate(second_embed):
        res2[i, :original_len[i], FE_DIM:] = e
    return res2

Which looks better is probably a question of style; my style is the second.
A quick timing in IPython with

%timeit first()
%timeit second()

gives

10000 loops, best of 3: 111 µs per loop
10000 loops, best of 3: 76 µs per loop

so the second is slightly faster (but it might be different for your parameters).
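A third, loop-free variant is sketched below for comparison. It is a swapped-in technique (a length mask plus broadcasting, not something from the replies above), reuses the same illustrative setup, and I have not benchmarked it against first() and second():

```python
import torch

FE_DIM = 100
first_embed = torch.randn(5, 20, FE_DIM)
second_embed = torch.randn(5, 10)
original_len = [10, 8, 9, 7, 3]
PADDING_LEN = 20

def third():
    # (5, 20) boolean mask, True for the first original_len[i] steps of row i.
    lengths = torch.tensor(original_len)
    mask = torch.arange(PADDING_LEN).unsqueeze(0) < lengths.unsqueeze(1)
    # Broadcast (5, 1, 10) against (5, 20, 1) to get (5, 20, 10), zero past each length.
    second = second_embed.unsqueeze(1) * mask.unsqueeze(2).to(second_embed.dtype)
    return torch.cat([first_embed, second], dim=2)
```

One caveat of multiplying by the mask: non-finite values in second_embed (inf or NaN) become NaN rather than 0 in the padded positions.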

Best regards

Thomas

I used torch.nn.utils.rnn.pad_sequence in the collate function for my DataLoader:

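import torch

# Assumed global device (not defined in the original post).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-length sequences.

    Note: it converts things to tensors manually here, since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([t.shape[0] for t in batch]).to(device)
    ## pad (default batch_first=False, so the result is (max_len, batch, ...))
    batch = [torch.Tensor(t).to(device) for t in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask (assumes real values are never exactly 0)
    mask = (batch != 0).to(device)
    return batch, lengths, mask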
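A sketch of how such a collate function can be wired into a DataLoader and fed to an RNN through pack_padded_sequence. The dataset, collate_pad and the GRU sizes are hypothetical; this version uses batch_first=True and leaves device transfer to the training loop:

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader

# Illustrative dataset: a plain list of 1-D float sequences of varying length.
data = [torch.randn(n) for n in (4, 2, 5, 3)]

def collate_pad(batch):
    # Hypothetical collate function: pad to the longest sequence in the batch.
    lengths = torch.tensor([t.size(0) for t in batch])
    padded = pad_sequence(batch, batch_first=True)  # (B, max_len)
    return padded, lengths

loader = DataLoader(data, batch_size=2, collate_fn=collate_pad)
rnn = torch.nn.GRU(input_size=1, hidden_size=8, batch_first=True)

for padded, lengths in loader:
    # Lengths must be on the CPU; enforce_sorted=False keeps the original order.
    packed = pack_padded_sequence(padded.unsqueeze(-1), lengths,
                                  batch_first=True, enforce_sorted=False)
    output, h_n = rnn(packed)
    print(padded.shape, h_n.shape)
```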

There are many related posts, including some on bucketing, and there is even a Stack Overflow question about this.

crossposted: https://www.quora.com/unanswered/How-does-Pytorch-Dataloader-handle-variable-size-data

Here is a generalised version for a list of jagged tensors that vary in length only along the first dim (all other dims must match):

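def padding_tensor(sequences):
    """
    :param sequences: list of tensors that differ only in their first dimension
    :return: (out_tensor, mask), both of shape (len(sequences), max_len, *trailing_dims);
             mask is 1 at real positions and 0 at padding
    """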
    num = len(sequences)
    max_len = max([s.shape[0] for s in sequences])
    out_dims = (num, max_len, *sequences[0].shape[1:])
    out_tensor = sequences[0].data.new(*out_dims).fill_(0)
    mask = sequences[0].data.new(*out_dims).fill_(0)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        out_tensor[i, :length] = tensor
        mask[i, :length] = 1
    return out_tensor, mask
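For comparison, pad_sequence also accepts sequences of shape (L_i, *) with matching trailing dims, so it covers the same case. A sketch with illustrative shapes that additionally builds a per-step (num, max_len) mask from the lengths and broadcasts it to the full shape when an element-wise mask is wanted:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative data: sequences of shape (L_i, 2, 3) that differ only in L_i.
seqs = [torch.randn(4, 2, 3), torch.randn(1, 2, 3), torch.randn(3, 2, 3)]

# pad_sequence pads the first dimension; trailing dimensions must match.
out = pad_sequence(seqs, batch_first=True)  # (3, 4, 2, 3)

# Per-step boolean mask of shape (num, max_len).
lengths = torch.tensor([s.size(0) for s in seqs])
step_mask = torch.arange(out.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

# Element-wise mask with the same shape as `out` (an expanded view, no copy).
full_mask = step_mask[:, :, None, None].expand_as(out)
```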