Optimizing diagonal stripe code

Not sure why you would set the argument `dim=3`: the tensor `a` is 3-dimensional, so it only has dimensions 0, 1 and 2 (the last one can also be written as -1, as in the code below). Maybe this is what you want?

import torch


def flip(x, dim):
    """Return a copy of ``x`` with its elements reversed along dimension ``dim``.

    Delegates to ``torch.flip``, which reverses the dimension in one native
    op instead of materialising a reversed ``arange`` index tensor and going
    through generic advanced indexing.

    Args:
        x: input tensor.
        dim: the dimension to reverse; negative values count from the end
            (e.g. ``-1`` is the last dimension). PyTorch reports an error if
            it is out of range for ``x``.

    Returns:
        A new tensor with the same shape as ``x``. ``torch.flip`` copies the
        data, so the result does not share memory with ``x``.
    """
    return torch.flip(x, dims=(dim,))


def batch_stripe(a):
    """Return the diagonal "stripes" of ``a`` as a zero-copy strided view.

    For an input of shape ``(*batch, i, j)`` with ``i >= j`` the result has
    shape ``(*batch, i - j + 1, j)`` and satisfies

        out[..., r, c] == a[..., r + c, c]

    i.e. row ``r`` of the output is the main diagonal of the ``j x j`` window
    of ``a`` whose top row is ``r``. The 3-D ``(b, i, j)`` case is a single
    batch dimension. Any number of leading dimensions, including none, is
    accepted.

    The result is created with ``torch.as_strided`` and therefore shares
    storage with ``a``: in-place writes to it modify ``a``.

    Args:
        a: tensor with at least two dimensions whose second-to-last size is
            greater than or equal to its last size.

    Returns:
        The strided stripe view described above.

    Raises:
        ValueError: if ``a`` has fewer than two dimensions, or if
            ``a.size(-2) < a.size(-1)``.
    """
    if a.dim() < 2:
        raise ValueError(
            'batch_stripe expects a tensor with at least 2 dims, got {}'
            .format(a.dim()))
    *batch, i, j = a.size()
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, and for i == j - 1 as_strided would then quietly return
    # an empty tensor.
    if i < j:
        raise ValueError(
            'batch_stripe needs size(-2) >= size(-1), got {} < {}'
            .format(i, j))
    *batch_strides, row_stride, col_stride = a.stride()
    # One step along output rows moves one row in ``a`` (row_stride); one
    # step along output columns moves one row AND one column
    # (row_stride + col_stride), which walks down the diagonal.
    size = tuple(batch) + (i - j + 1, j)
    stride = tuple(batch_strides) + (row_stride, row_stride + col_stride)
    # Pass a's storage offset explicitly so the view starts at a's first
    # element even when ``a`` is itself a slice such as ``x[:, 1:]``.
    return torch.as_strided(a, size, stride, a.storage_offset())


if __name__ == '__main__':
    # Demo input: a batch of two 4x3 matrices filled with 0..23.
    demo = torch.arange(24).view(2, 4, 3)
    print('original tensor:', demo, sep='\n')
    # Reversing the last dimension first makes each stripe follow an
    # anti-diagonal instead of a diagonal.
    print('inverse stripe:', batch_stripe(flip(demo, -1)), sep='\n')
1 Like