Memory-efficient repeat with shift

So my problem is as follows: I start with a tensor of shape (h, w) and a tensor of shape (n, 2), which serves as guidance for n torch.roll operations. From these two tensors I want to construct an (n, h, w) tensor, where the i-th slice is the (h, w) input with torch.roll applied, the two shift amounts (one per axis) coming from row i of the (n, 2) tensor.

Would there be any way to do this without allocating (n, h, w) memory — just as torch.expand doesn't allocate new memory, but rather returns a view of the original tensor?
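For reference, the straightforward version that does allocate the full (n, h, w) result (the baseline I'd like to avoid) would look something like this — `rolled_stack` is just an illustrative name:

```python
import torch

def rolled_stack(x, shifts):
    """Naive baseline: materializes the full (n, h, w) result."""
    # Row k of `shifts` gives the roll amounts for slice k, one per axis.
    return torch.stack([
        torch.roll(x, shifts=(int(s0), int(s1)), dims=(0, 1))
        for s0, s1 in shifts
    ])

x = torch.arange(6).reshape(2, 3)
shifts = torch.tensor([[0, 0], [0, 1], [1, 0]])
out = rolled_stack(x, shifts)  # shape (3, 2, 3)
```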

This would be really easy and intuitive with first-class dimensions (torchdim), but it requires a PyTorch nightly build (it hasn't reached a stable release yet): GitHub - facebookresearch/torchdim: Named tensors with first-class dimensions for PyTorch
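Based on the indexing patterns in the torchdim README, the idea would look roughly like the sketch below. This is untested pseudocode against a nightly-only API, so the exact names (`dims`, `.order`) and semantics may differ:

```python
# Untested sketch -- torchdim API details may differ.
from torchdim import dims

h, w = x.shape
n, i, j = dims(3)
s0 = shifts[n, 0]  # per-slice row shift, bound to first-class dim n
s1 = shifts[n, 1]  # per-slice column shift
# Indexing with dim expressions describes the rolled stack lazily:
rolled = x[(i - s0) % h, (j - s1) % w]
result = rolled.order(n, i, j)  # only here would (n, h, w) be materialized
```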

Short of that, I think your best bet is to write a Numba (CPU) or Triton (GPU) kernel.
