Tensor shifts in torch.roll

Hi all!

I have multiple images (or feature maps). I want to roll the images using torch.roll.

 torch.roll(images, shifts, dims=(2,3))

shifts must be an int or a tuple of ints.
But I have a shifts tensor with a different amount for each image.
Is there a way to roll each image by a different amount in a vectorized manner?
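For reference, the non-vectorized version of what is being asked for is a Python loop that calls torch.roll once per image, converting each row of the shifts tensor to a tuple of ints. A minimal sketch, assuming images of shape (nBatch, nChannels, h, w) and an integer shifts tensor of shape (nBatch, 2); the helper name roll_per_image_loop is hypothetical:

```python
import torch

def roll_per_image_loop(images: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: roll each image by its own (dy, dx) using a loop.

    images: (nBatch, nChannels, h, w); shifts: integer tensor of shape (nBatch, 2).
    """
    # torch.roll only accepts Python ints (or tuples of ints) for shifts,
    # so each image is rolled separately and the results are stacked.
    return torch.stack([
        torch.roll(img, shifts=(int(s[0]), int(s[1])), dims=(1, 2))
        for img, s in zip(images, shifts)
    ])

# Two 1-channel 3x3 images: roll the first down by 1 row,
# the second right by 2 columns.
images = torch.arange(2 * 1 * 3 * 3).reshape(2, 1, 3, 3)
shifts = torch.tensor([[1, 0], [0, 2]])
out = roll_per_image_loop(images, shifts)
print(out[0, 0].tolist())  # [[6, 7, 8], [0, 1, 2], [3, 4, 5]]
```

This works, but the Python loop runs once per image, which is what the vectorized approach avoids.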

Thank you!

Hi Meta!

Yes, you can create “identity” indices for each of the four dimensions of
your images tensor. Then “roll” the identity indices for your height and
width dimensions (your dims = (2, 3)) separately for each image in the batch,
and create your rolled images by using pytorch tensor (advanced) indexing to
index into images with your “rolled” indices.
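The same idea in one dimension, as a toy sketch with made-up values: build the identity indices 0..n-1 for each row, shift them by that row's amount modulo n, and index with the result. Subtracting the shift reproduces torch.roll's convention, out[i] = x[(i - shift) % n]:

```python
import torch

# Two 1-D "images" of length 5, each to be rolled by its own amount.
x = torch.tensor([[10, 11, 12, 13, 14],
                  [20, 21, 22, 23, 24]])
shifts = torch.tensor([1, 3])
n = x.shape[1]

# Identity indices 0..n-1 for every row, shape (2, n).
idx = torch.arange(n).expand(x.shape[0], n)

# "Roll" the identity indices per row; % wraps them back into [0, n).
rolled_idx = (idx - shifts[:, None]) % n

# Index the data with the rolled indices; the row index (2, 1)
# broadcasts against the column indices (2, n).
rows = torch.arange(x.shape[0])[:, None]
out = x[rows, rolled_idx]
print(out.tolist())  # [[14, 10, 11, 12, 13], [22, 23, 24, 20, 21]]
```

Each row of out equals torch.roll applied to that row with its own shift.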

Here is an illustration:

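import torch

nBatch = 5
nChannels = 3
h = 24
w = 32

images = torch.randn(nBatch, nChannels, h, w)
# one (row, column) shift per image, shape (nBatch, 2)
shifts = torch.stack((torch.randint(h, (nBatch,)), torch.randint(w, (nBatch,))), -1)

# "identity" indices for each of the four dimensions, all of shape (nBatch, nChannels, h, w)
ind0 = torch.arange(nBatch)[:, None, None, None].expand(nBatch, nChannels, h, w)
ind1 = torch.arange(nChannels)[None, :, None, None].expand(nBatch, nChannels, h, w)
ind2 = torch.arange(h)[None, None, :, None].expand(nBatch, nChannels, h, w)
ind3 = torch.arange(w)[None, None, None, :].expand(nBatch, nChannels, h, w)

# "roll" the height and width indices per image; subtracting the shift (mod h, w)
# matches torch.roll's sign convention, so rolled_images[b] equals
# torch.roll(images[b], tuple(shifts[b].tolist()), dims=(1, 2))
rolled_images = images[ind0, ind1, (ind2 - shifts[:, 0, None, None, None]) % h, (ind3 - shifts[:, 1, None, None, None]) % w]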
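Two optional refinements, sketched under the same shape assumptions: the .expand() calls are not strictly needed, because advanced indexing broadcasts the four index tensors against each other, and the result can be spot-checked against torch.roll for a single image and by rolling back with the negated shifts. The function name roll_per_image is hypothetical:

```python
import torch

def roll_per_image(images: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: roll images[b] by (shifts[b, 0], shifts[b, 1]) along dims (2, 3).

    images: (nBatch, nChannels, h, w); shifts: integer tensor of shape (nBatch, 2).
    """
    nBatch, nChannels, h, w = images.shape
    # Index tensors with singleton dims; advanced indexing broadcasts them
    # to (nBatch, nChannels, h, w) without explicit .expand() calls.
    ind0 = torch.arange(nBatch)[:, None, None, None]
    ind1 = torch.arange(nChannels)[None, :, None, None]
    ind2 = torch.arange(h)[None, None, :, None]
    ind3 = torch.arange(w)[None, None, None, :]
    # out[b, c, i, j] = images[b, c, (i - dy) % h, (j - dx) % w], as in torch.roll.
    rows = (ind2 - shifts[:, 0, None, None, None]) % h
    cols = (ind3 - shifts[:, 1, None, None, None]) % w
    return images[ind0, ind1, rows, cols]

torch.manual_seed(0)
nBatch, nChannels, h, w = 5, 3, 24, 32
images = torch.randn(nBatch, nChannels, h, w)
shifts = torch.stack((torch.randint(h, (nBatch,)), torch.randint(w, (nBatch,))), -1)

rolled = roll_per_image(images, shifts)

# Spot check: image 0 should match torch.roll with its own shift.
same_as_roll = torch.equal(rolled[0], torch.roll(images[0], tuple(shifts[0].tolist()), dims=(1, 2)))
# Rolling back by the negated shifts should recover the original images.
round_trip = torch.equal(roll_per_image(rolled, -shifts), images)
print(same_as_roll, round_trip)  # True True
```

Since .expand() returns views rather than copies, dropping it is mainly a simplification; both versions produce the same result.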

Best.

K. Frank
