Faster way to apply padding to batch 3D tensor?

Hello,

I am applying left-sided replication padding of width i to a tensor X with shape

torch.Size([1024, 5, 10, 50])

to obtain a tensor of size

torch.Size([1024, 5, 10, 55])

using the command

import torch.nn.functional as F
F.pad(X,pad=(i,0,0,0),mode='replicate')

Timing this operation in my notebook with %timeit yields:

for i in range(1,20):
    print(i,end=': ')
    %timeit F.pad(X,pad=(i,0,0,0),mode='replicate')
1: 2.9 ms ± 89.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2: 2.92 ms ± 97.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3: 3.05 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4: 3.1 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5: 3.1 ms ± 72.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6: 3.21 ms ± 35.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7: 3.2 ms ± 89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8: 3.28 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9: 3.4 ms ± 62.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10: 3.42 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
11: 3.52 ms ± 89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12: 3.58 ms ± 96.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13: 3.63 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
14: 3.75 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
15: 3.82 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16: 3.82 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
17: 3.9 ms ± 62.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
18: 3.94 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19: 4.05 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This seems quite slow if it needs to be done over and over again (it is also a bit surprising that it slows down considerably as the padding width grows). Is there a way to do this faster? Otherwise, I would try to apply the padding to X while it is still a NumPy array, although I haven't tested yet whether that would be faster.
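
In case I go the NumPy route, something like np.pad with mode='edge' (which replicates the edge values) should give the same left-sided padding; an untested sketch:

import numpy as np

X_np = np.random.rand(1024, 5, 10, 50).astype(np.float32)
i = 5
# pad only the last axis, i entries on the left, replicating the edge values
X_np_padded = np.pad(X_np, pad_width=((0, 0), (0, 0), (0, 0), (i, 0)), mode='edge')
print(X_np_padded.shape)  # (1024, 5, 10, 55)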

Thanks!

Best, JZ

Hi,

maybe I did not understand correctly, but if you ONLY want to pad (with zeros) on the left side, then you could use the following code.

import torch

your_matrix = torch.rand(1024, 5, 10, 50)
zero_matrix = torch.zeros(1024, 5, 10, 5)

print(your_matrix.shape)

your_matrix = torch.cat((zero_matrix, your_matrix), dim=3)
print(your_matrix.shape)
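
For completeness, the same left-sided zero padding can also be done directly with F.pad in constant mode; a sketch on a fresh (1024, 5, 10, 50) tensor:

import torch.nn.functional as F

m = torch.rand(1024, 5, 10, 50)
# pad the last dimension with 5 zeros on the left only
padded = F.pad(m, pad=(5, 0, 0, 0), mode='constant', value=0)
print(padded.shape)  # torch.Size([1024, 5, 10, 55])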

If that is not the case, please let me know.

Hey Matias,

thanks for your reply! I would rather pad in replication mode, as in my example above, i.e. fill the left-sided pad with the leftmost entry of every matrix row. Building on your idea of using torch.cat, combined with torch.Tensor.expand, and given

x = torch.rand(1024, 5, 10, 50).to('cuda')

I wrote this small function:

def batch_pad_left(x, pad):
    # leftmost entry of each row, kept as a size-1 trailing dimension
    left_col = x[:, :, :, 0].unsqueeze(-1)
    # expand (a view, no copy) to the desired pad width
    left_col_expand = left_col.expand((1024, 5, 10, pad))
    # concatenate the replicated column in front of x along the last dimension
    x_padded = torch.cat([left_col_expand, x], dim=-1)
    return x_padded

Comparing the outputs of F.pad and this function

(F.pad(x,pad=(5,0,0,0),mode='replicate')==batch_pad_left(x,pad=5)).all()

yields:

tensor(True)

So, this works.
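
For reference, a slightly more general variant that avoids hard-coding the batch shape (just a sketch, the name is my own) would be:

def batch_pad_left_any_shape(x, pad):
    # leftmost entry along the last dimension, kept as a size-1 trailing dim
    left_col = x[..., :1]
    # expand is a view, so this does not copy memory yet
    left_col_expand = left_col.expand(*x.shape[:-1], pad)
    # cat materializes the padded tensor
    return torch.cat([left_col_expand, x], dim=-1)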

Timing the operations with

for i in range(1,20):
    print(i,end=': ')
    %timeit F.pad(x,pad=(i,0,0,0),mode='replicate')

and

for i in range(1,20):
    print(i,end=': ')
    %timeit x_padded = batch_pad_left(x,pad=i)

returns (F.pad)

1: 35.2 µs ± 16.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2: 209 µs ± 7.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3: 293 µs ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
4: 326 µs ± 5.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5: 333 µs ± 5.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
6: 332 µs ± 7.72 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
7: 332 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
8: 339 µs ± 6.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
9: 360 µs ± 20.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
10: 377 µs ± 4.85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
11: 381 µs ± 5.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
12: 383 µs ± 4.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
13: 310 µs ± 67.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14: 414 µs ± 182 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
15: 608 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16: 578 µs ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
17: 576 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
18: 276 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
19: 540 µs ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

and (cat+expand)

1: 607 µs ± 74.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2: 357 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3: 654 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
4: 640 µs ± 26.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5: 660 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
6: 318 µs ± 53.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
7: 467 µs ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
8: 466 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
9: 468 µs ± 8.24 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
10: 465 µs ± 5.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
11: 467 µs ± 7.94 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
12: 465 µs ± 7.37 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
13: 466 µs ± 6.98 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14: 585 µs ± 136 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
15: 631 µs ± 25.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16: 651 µs ± 32.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
17: 554 µs ± 170 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
18: 455 µs ± 190 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
19: 664 µs ± 24.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

respectively for F.pad and expand+cat. (Note that the significant speedup from ms to µs comes from running this on my CUDA device this time, rather than the CPU.)
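
One caveat I should keep in mind: CUDA kernels launch asynchronously, so %timeit without explicit synchronization may not capture the full kernel time. A minimal sketch of a synchronized timing helper (the name time_cuda is just my own) would be:

def time_cuda(fn, iters=100):
    # warm up, then make sure no earlier work is still queued
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

print(time_cuda(lambda: F.pad(x, pad=(5, 0, 0, 0), mode='replicate')))
print(time_cuda(lambda: batch_pad_left(x, pad=5)))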

Unfortunately, the expand+cat function seems to slow the operation down in most cases. Maybe there is a better way to write that function? I am still new to torch operations. Otherwise, I will probably go with F.pad.

Best, JZ