Not sure why you would want to set the argument dim=3 when the tensor a is
three-dimensional and therefore only has dimensions 0, 1 and 2 (or their
negative counterparts -3, -2, -1). Maybe this is what you want?
import torch
def flip(x, dim):
    """Return a copy of ``x`` with its entries reversed along ``dim``.

    Args:
        x: Tensor to reverse.
        dim: Axis to reverse; negative values count from the last axis.

    Returns:
        A tensor with the same shape as ``x`` whose elements along ``dim``
        appear in the opposite order.
    """
    # Descending positions size-1, ..., 1, 0, created on the same device as x.
    last_position = x.size(dim) - 1
    descending = torch.arange(
        last_position, -1, -1, dtype=torch.long, device=x.device
    )
    # Take every axis in full, except ``dim`` which is read back to front.
    selector = [slice(None) for _ in range(x.dim())]
    selector[dim] = descending
    return x[tuple(selector)]
def batch_stripe(a):
    """Return the diagonal stripes of each matrix in a 3-D batch as a view.

    For an input of shape ``(b, i, j)`` the result has shape
    ``(b, i - j + 1, j)`` and satisfies ``out[n, r, c] == a[n, r + c, c]``:
    row ``r`` of the output is the diagonal starting at ``a[n, r, 0]``.

    Args:
        a: 3-D tensor whose second dimension is at least as large as its
            third.

    Returns:
        A strided view over ``a`` (no data is copied).

    Raises:
        AssertionError: If the second dimension is smaller than the third.
    """
    batch, rows, cols = a.shape
    assert rows >= cols
    batch_step, row_step, col_step = a.stride()
    stripe_shape = (batch, rows - cols + 1, cols)
    # Stepping one column to the right also steps one row down, so the
    # last axis of the view walks along a diagonal of the matrix.
    stripe_strides = (batch_step, row_step, row_step + col_step)
    return torch.as_strided(a, stripe_shape, stripe_strides)
if __name__ == '__main__':
    # Demo: build a 2 x 4 x 3 batch holding 0..23, then show the stripes of
    # the batch after reversing its last dimension.
    sample = torch.arange(24).view(2, 4, 3)
    print('original tensor:', sample, sep='\n')
    print('inverse stripe:')
    reversed_last = flip(sample, -1)
    print(batch_stripe(reversed_last))