Hi Meta!
Yes, you can create “identity” indices for each of the four dimensions of your images tensor. Then “roll” the identity indices for your height and width dimensions (your dims = (2, 3)) separately for each batch element, and create your rolled images by using PyTorch tensor indexing to index into images with your “rolled” indices.
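Reduced to one dimension, the trick is simply indexing with a shifted, wrapped-around copy of torch.arange. Here is a minimal sketch, where x and s are illustrative names that do not come from the question:

```python
import torch

x = torch.arange(6)   # illustrative data: tensor([0, 1, 2, 3, 4, 5])
s = 2                 # illustrative shift
n = x.numel()

# "rolled" identity indices: (0..n-1 shifted by s), wrapped around with %
idx = (torch.arange(n) + s) % n   # tensor([2, 3, 4, 5, 0, 1])
rolled = x[idx]
print(rolled)                     # tensor([2, 3, 4, 5, 0, 1])

# Adding the shift reads "ahead", which matches torch.roll with the negated shift
print(torch.equal(rolled, torch.roll(x, -s)))  # True
```

Note the direction: indexing with (arange + s) % n rolls the opposite way from torch.roll (x, s), so subtract s instead if you want to match torch.roll exactly.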
Here is an illustration:
```python
import torch

nBatch = 5
nChannels = 3
h = 24
w = 32
images = torch.randn (nBatch, nChannels, h, w)

# one random (height, width) shift per batch element, shape [nBatch, 2]
shifts = torch.stack ((torch.randint (h, (nBatch,)), torch.randint (w, (nBatch,))), -1)

# "identity" indices for each of the four dimensions, expanded to the full images shape
ind0 = torch.arange (nBatch)[:, None, None, None].expand (nBatch, nChannels, h, w)
ind1 = torch.arange (nChannels)[None, :, None, None].expand (nBatch, nChannels, h, w)
ind2 = torch.arange (h)[None, None, :, None].expand (nBatch, nChannels, h, w)
ind3 = torch.arange (w)[None, None, None, :].expand (nBatch, nChannels, h, w)

# offset the height and width indices by each sample's shift (wrapping around with %)
# and gather with advanced indexing
rolled_images = images[ind0, ind1, (ind2 + shifts[:, 0, None, None, None]) % h, (ind3 + shifts[:, 1, None, None, None]) % w]
```
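If you want to convince yourself that this is equivalent to calling torch.roll on each sample, a sketch of a check against a per-sample loop follows. Two small liberties are taken relative to the code above: the .expand() calls are dropped (advanced indexing broadcasts the index tensors anyway), and the shifts are subtracted rather than added so that the direction matches torch.roll (…, dims = (1, 2)) on each [C, H, W] sample. The seed and the helper names sh and sw are illustrative.

```python
import torch

torch.manual_seed(0)  # illustrative seed, only for reproducibility
nBatch, nChannels, h, w = 5, 3, 24, 32
images = torch.randn(nBatch, nChannels, h, w)
shifts = torch.stack((torch.randint(h, (nBatch,)), torch.randint(w, (nBatch,))), -1)

# Identity indices; advanced indexing broadcasts them, so .expand() is optional here
ind0 = torch.arange(nBatch)[:, None, None, None]
ind1 = torch.arange(nChannels)[None, :, None, None]
ind2 = torch.arange(h)[None, None, :, None]
ind3 = torch.arange(w)[None, None, None, :]

# Per-sample shifts, shaped [nBatch, 1, 1, 1] so they broadcast over C, H, W
sh = shifts[:, 0, None, None, None]
sw = shifts[:, 1, None, None, None]

# Subtracting matches torch.roll's direction: out[i] = x[(i - shift) % n]
rolled_images = images[ind0, ind1, (ind2 - sh) % h, (ind3 - sw) % w]

# Reference result: loop over the batch, rolling each [C, H, W] sample
expected = torch.stack([
    torch.roll(images[b], shifts=(int(shifts[b, 0]), int(shifts[b, 1])), dims=(1, 2))
    for b in range(nBatch)
])
print(torch.equal(rolled_images, expected))  # True
```

Because the indexing only gathers existing values, the comparison can use exact equality rather than a tolerance.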
Best.
K. Frank