How to do this operation fast?

I have some video clips of different lengths. The clip lengths are stored in seq_len = [T_1, T_2, …, T_n], and the clips themselves are concatenated along the first dimension into a tensor of shape [\sum T_i, C, H, W].

Now I want to convert it into shape [B, C, \max T_i, H, W] (with B = n), i.e., zero-pad the shorter clips and stack them into a batch that can be passed to nn.Conv3d.

Here is an implementation using loops:

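import torch
import torch.nn.functional as F

def pack_x(x, seq_len):
    # x: [sum(T_i), C, H, W]; seq_len: 1-D integer tensor of clip lengths
    T = seq_len.max().item()
    # Move channels first so that time is dim 1: [C, sum(T_i), H, W]
    x = x.transpose(0, 1)
    # One [C, T_i, H, W] chunk per clip
    clips = torch.split(x, seq_len.tolist(), dim=1)

    padded = []
    for clip in clips:
        # F.pad pairs run from the last dim backwards (W, H, T); pad only the end of T
        padded.append(F.pad(clip, [0, 0, 0, 0, 0, T - clip.size(1)]))
    # Stack into [B, C, T, H, W]
    return torch.stack(padded)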


seq_len = torch.randint(5, 10, size=[5])
inputs = torch.randn(seq_len.sum(), 512, 28, 28)
y = pack_x(inputs, seq_len)

You can override collate_fn in the DataLoader:

torch.utils.data.DataLoader(dataset=train_dataset, batch_size=setting.batch_size, num_workers=opt.nThreads, collate_fn=pack_x_new)

You need to adapt pack_x into a pack_x_new that consumes the samples returned by the dataset's __getitem__() method.
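
A pack_x_new along these lines might look like the sketch below. It assumes (this is not stated in the thread) that __getitem__() returns a single clip as a tensor of shape [T_i, C, H, W], and it uses torch.nn.utils.rnn.pad_sequence for the zero-padding. Returning the lengths alongside the batch is an illustrative choice, not a requirement.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pack_x_new(batch):
    # Assumption: `batch` is a list of clips, each a tensor of shape [T_i, C, H, W],
    # i.e. what a (hypothetical) dataset's __getitem__ returns for one sample.
    seq_len = torch.tensor([clip.size(0) for clip in batch])
    # pad_sequence zero-pads along the first dim: [B, max T_i, C, H, W]
    padded = pad_sequence(batch, batch_first=True)
    # Move channels ahead of time for nn.Conv3d: [B, C, max T_i, H, W]
    return padded.permute(0, 2, 1, 3, 4).contiguous(), seq_len


clips = [torch.randn(t, 3, 4, 4) for t in (5, 7, 6)]
x, seq_len = pack_x_new(clips)
```

This only helps when the padding can happen at data-loading time, since collate_fn runs before the batch reaches the model.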

Hi @Naruto-Sasuke, I need to do this in the middle of the model's forward pass.
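
For the in-forward case, one possible way to avoid the Python loop is to scatter the frames into a zero tensor with a boolean mask. The following is only a sketch: the name pack_x_masked is made up for illustration, and it assumes the frames in x are ordered clip by clip, as in the question. Whether it is actually faster than the loop depends on the sizes and the device, so it is worth profiling.

```python
import torch


def pack_x_masked(x, seq_len):
    # x: [sum(T_i), C, H, W]; seq_len: 1-D integer tensor with the clip lengths
    seq_len = seq_len.to(x.device)
    B, T = seq_len.numel(), int(seq_len.max())
    # mask[b, t] is True where clip b has a real frame at time step t
    mask = torch.arange(T, device=x.device)[None, :] < seq_len[:, None]
    out = x.new_zeros(B, T, *x.shape[1:])
    # Row-major order of the True entries matches the frame order in x
    out[mask] = x
    # [B, T, C, H, W] -> [B, C, T, H, W]
    return out.permute(0, 2, 1, 3, 4)


seq_len = torch.randint(5, 10, size=[5])
inputs = torch.randn(int(seq_len.sum()), 8, 4, 4)
y = pack_x_masked(inputs, seq_len)
```

The result is a permuted view, so call .contiguous() on it if a later op needs contiguous memory.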