I see that there is a way to flip the matrix, which would allow me to use the original implementation. The flip function is defined here https://github.com/pytorch/pytorch/issues/229:
def flip(x, dim):
    dim = x.dim() + dim if dim < 0 else dim
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]
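As a sanity check on that helper, here is a minimal sketch that compares it with the built-in `torch.flip`. It assumes a PyTorch version recent enough to ship `torch.flip`; the test tensor `a` and its shape are made up for illustration.

```python
import torch

def flip(x, dim):
    # Same helper as above: reverse x along one axis via index selection.
    dim = x.dim() + dim if dim < 0 else dim
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

# Hypothetical 3-D test tensor (batch=2, rows=3, cols=4).
a = torch.arange(24).reshape(2, 3, 4)

# Every valid axis index for a 3-D tensor, negative and non-negative.
for d in (-3, -2, -1, 0, 1, 2):
    assert torch.equal(flip(a, d), torch.flip(a, dims=[d]))
```

If the two agree for every axis, the helper could presumably be swapped for `torch.flip(x, dims=[dim])` on versions that have it.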
So I can do:
def batch_stripe(a):
    b, i, j = a.size()
    assert i >= j
    b_s, k, l = a.stride()
    return torch.as_strided(a, (b, i - j, j), (b_s, k, k+1))
# this simulates numpy stripe(a[..., ::-1, :])[..., ::-1, :] ???
output = flip(batch_stripe(flip(a, 3)), 3)
But when I set dim to 3, I get this:
RuntimeError: dimension out of range (expected to be in range of [-3, 2], but got 3)
Is dim actually referring to tensor rank in this case in PyTorch? dim=2 seems to work, but I am not sure if that is correct.
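To check this for myself, here is a minimal sketch, assuming a 3-D contiguous input and a PyTorch version that has the built-in `torch.flip` (used here in place of the helper above). The test tensor and the variable names `ref`, `out` and `raised` are only illustrative. As far as I can tell, `dim` is a zero-based axis index, so a 3-D tensor accepts 0, 1, 2 (or -3, -2, -1), and the numpy expression `a[..., ::-1, :]` reverses the second-to-last axis, which would be `dim=1` or `dim=-2` rather than `dim=2`.

```python
import numpy as np
import torch

def batch_stripe(a):
    b, i, j = a.size()
    assert i >= j
    b_s, k, l = a.stride()
    # k + 1 assumes the last axis has unit stride (contiguous input).
    return torch.as_strided(a, (b, i - j, j), (b_s, k, k + 1))

# Hypothetical test tensor: batch=2, i=5 rows, j=3 columns.
a = torch.arange(2 * 5 * 3).reshape(2, 5, 3)
assert a.dim() == 3  # rank 3, so axis 3 does not exist

# Asking for axis 3 fails; the exception type varies across versions.
try:
    torch.flip(a, dims=[3])
    raised = False
except (IndexError, RuntimeError):
    raised = True

# numpy's a[..., ::-1, :] reverses the second-to-last axis, i.e. dim=-2.
ref = a.numpy()[..., ::-1, :]
assert np.array_equal(torch.flip(a, dims=[-2]).numpy(), ref)

# Flip rows, take the stripe, flip rows back.
out = torch.flip(batch_stripe(torch.flip(a, dims=[-2])), dims=[-2])

# Under these assumptions, out[n, r, c] == a[n, j + r - c, c].
b, i, j = a.size()
for n in range(b):
    for r in range(i - j):
        for c in range(j):
            assert out[n, r, c] == a[n, j + r - c, c]
```

With `dim=2` the call runs without error, but it reverses the columns instead of the rows, so the result would not match the numpy expression in the comment.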