Can I make this sliding-window slicing of a tensor faster?

Hello,
I am doing something like this in my code, but it is very slow.
I was wondering if there is a faster way to do it.
In my example, shift and ndivides change from case to case.

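```python
import torch

x = torch.arange(0, 32*400).reshape(1, 32, 400)
cat = []
shift = 4
ndivides = 99
for i in range(ndivides):
    # window i starts at i*shift and is shift + (shift//2)*2 elements long
    cat.append(x[:, :, i*shift:(i+1)*shift + (shift//2)*2])
results = torch.stack(cat, dim=1)
```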

Sometimes the code fails when I change shift or ndivides, but that is not the problem. Unfortunately, the main issue is that the loop is too slow.
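The failure mentioned above is most likely caused by the last windows running past the end of the tensor: slicing silently truncates them, and `torch.stack` then rejects tensors of different sizes. A minimal sketch, assuming the window length is `shift + (shift // 2) * 2` as in the loop; `max_windows` is a hypothetical helper name, not part of PyTorch:

```python
import torch

def max_windows(length: int, shift: int) -> int:
    """Hypothetical helper: number of full windows the loop can produce."""
    size = shift + (shift // 2) * 2  # window length used in the loop
    return (length - size) // shift + 1

x = torch.arange(0, 32 * 400).reshape(1, 32, 400)
shift = 4
ndivides = max_windows(x.shape[2], shift)  # 99 for shift=4 and length 400

# One window too many: the slice end (404) is past the end, so only 4 elements remain
i = ndivides
last = x[:, :, i * shift:(i + 1) * shift + (shift // 2) * 2]
print(last.shape)  # torch.Size([1, 32, 4]) instead of [1, 32, 8]

# torch.stack requires all tensors to have the same shape, so this raises
windows = [x[:, :, j * shift:(j + 1) * shift + (shift // 2) * 2] for j in range(ndivides + 1)]
failed = False
try:
    torch.stack(windows, dim=1)
except RuntimeError as err:
    failed = True
    print("stack failed:", err)
```

Choosing `ndivides <= max_windows(length, shift)` keeps every slice full length.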

`tensor.unfold` should work:

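```python
import torch

x = torch.arange(0, 32*400).reshape(1, 32, 400)

# reference result from the original loop
cat = []
shift = 4
ndivides = 99
for i in range(ndivides):
    cat.append(x[:, :, i*shift:(i+1)*shift + (shift//2)*2])
results = torch.stack(cat, dim=1)

# unfold(dimension, size, step): windows of 8 elements, taken every 4 elements along dim 2
tmp = x.unfold(2, 8, 4).permute(0, 2, 1, 3)
print((tmp == results).all())
# tensor(True)
```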
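`Tensor.unfold(dimension, size, step)` shrinks `dimension` to the number of windows, here `(400 - 8) // 4 + 1 = 99`, and appends a new last dimension of length `size`; the `permute` then moves the window index next to the batch dimension, matching `torch.stack(cat, dim=1)`. Both calls return views rather than copies, which is a large part of why this avoids the cost of the Python loop. A short sketch checking these shapes, assuming the same input as above:

```python
import torch

x = torch.arange(0, 32 * 400).reshape(1, 32, 400)

# unfold(dimension, size, step): windows of length 8, starting every 4 elements along dim 2
windows = x.unfold(2, 8, 4)
print(windows.shape)  # torch.Size([1, 32, 99, 8]); 99 == (400 - 8) // 4 + 1

# Move the window index next to the batch dimension
tmp = windows.permute(0, 2, 1, 3)
print(tmp.shape)  # torch.Size([1, 99, 32, 8])

# unfold and permute return views, so no data has been copied
print(tmp.data_ptr() == x.data_ptr())  # True
print(tmp.is_contiguous())  # False

# Window k along dim 1 covers elements 4*k .. 4*k + 7 of the last dimension
print(torch.equal(tmp[:, 5], x[:, :, 20:28]))  # True
```

If a later operation needs contiguous memory, `tmp.contiguous()` performs the copy once.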

Thank you, ptrblck.
In the unfold version, how can I make it general with respect to shift?
Should it be `x.unfold(2, 8, shift).permute(0, 2, 1, 3)`?

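Yes, but the window size also has to follow shift, since the loop uses shift + (shift//2)*2 elements per window (8 only when shift = 4). This should then yield the same results as your previous code once you adapt the ndivides value for the new shift: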

```python
tmp = x.unfold(2, shift + (shift//2)*2, shift).permute(0, 2, 1, 3)
```
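Putting this together, the sketch below wraps the unfold call in a hypothetical helper `sliding_windows` and checks it against the original loop (as hypothetical `loop_windows`) for several shift values. It assumes the same window length `shift + (shift // 2) * 2` as the question and derives `ndivides` from the input length, so the loop never produces a truncated window:

```python
import torch

def sliding_windows(x: torch.Tensor, shift: int) -> torch.Tensor:
    """Hypothetical helper: windows along dim 2, returned as (N, num_windows, C, size)."""
    size = shift + (shift // 2) * 2
    return x.unfold(2, size, shift).permute(0, 2, 1, 3)

def loop_windows(x: torch.Tensor, shift: int) -> torch.Tensor:
    """The original loop, with ndivides set to the number of full windows."""
    size = shift + (shift // 2) * 2
    ndivides = (x.shape[2] - size) // shift + 1
    # (i+1)*shift + (shift//2)*2 == i*shift + size
    cat = [x[:, :, i * shift:i * shift + size] for i in range(ndivides)]
    return torch.stack(cat, dim=1)

x = torch.arange(0, 32 * 400).reshape(1, 32, 400)
for shift in range(1, 11):
    fast = sliding_windows(x, shift)
    slow = loop_windows(x, shift)
    assert torch.equal(fast, slow), shift
    print(shift, tuple(fast.shape))
```

Because `unfold` computes the number of windows itself, only the loop version needs `ndivides` adjusted when shift changes.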