Problem with torchvision functional rotate

Greetings!

I have been trying to perform some data augmentation on my tensors, and I wanted to rotate them. This is a basic example of the code I am working with so far:

import torch
from torchvision.transforms.functional import rotate

img = torch.randn(4, 2, 11, 256, 256)

rotated_img = rotate(img, +90)

However, when trying to run it, I get the following error:

Traceback (most recent call last):
  File "basicrotate.py", line 7, in <module>
    rotated_img = rotate(img, +90)
  File "/home/javier/directory_env/spikingjelly_env/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 1078, in rotate
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
  File "/home/javier/directory_env/spikingjelly_env/lib/python3.8/site-packages/torchvision/transforms/functional_tensor.py", line 686, in rotate
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
  File "/home/javier/directory_env/spikingjelly_env/lib/python3.8/site-packages/torchvision/transforms/functional_tensor.py", line 586, in _apply_grid_transform
    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
  File "/home/javier/directory_env/spikingjelly_env/lib/python3.8/site-packages/torch/nn/functional.py", line 4201, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
RuntimeError: grid_sampler(): expected 4D or 5D input and grid with same number of dimensions, but got input with sizes [4, 2, 11, 256, 256] and grid with sizes [4, 256, 256, 2]

Could any of you help me fix this issue? According to the docs, the input to rotate can be any tensor of shape [..., H, W], with an arbitrary number of leading dimensions before the last two, so I really don't know what I'm missing here. Thank you very much for your help!

P.S. For additional information:
Torch version: 1.11.0+cu102
Torchvision version: 0.12.0+cu102

Javier, this doesn't address your direct question (why the call fails despite what the docs say), but as a quick workaround you could try the following, using the dimensions from your example:

rotated_img = rotate(img.view(4, 2 * 11, 256, 256), +90).view(4, 2, 11, 256, 256)

Thanks! That is precisely what I am doing so far, but I wanted to know if there was a native, more efficient way to make it work :sweat_smile:
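As for why the call fails: judging only from the traceback, the tensor code path in this torchvision version seems to build a 2-D sampling grid of shape [N, H, W, 2] and pass the 5-D tensor straight to grid_sample, which wants the input and the grid to agree on the number of spatial dimensions. The sketch below reproduces that mismatch with plain PyTorch calls, using small made-up sizes and an identity transform instead of a rotation, and shows that folding the extra dimension into the channels makes the shapes compatible.

```python
import torch
import torch.nn.functional as F

# Small illustrative sizes (the thread uses 4, 2, 11, 256, 256).
n, c, d, h, w = 2, 2, 3, 8, 8
img5d = torch.randn(n, c, d, h, w)

# Identity 2-D affine transform, giving a sampling grid of shape [N, H, W, 2].
theta = torch.eye(2, 3).unsqueeze(0).repeat(n, 1, 1)
grid = F.affine_grid(theta, size=(n, c, h, w), align_corners=False)

# A 5-D input combined with a 2-D sampling grid is rejected by grid_sample.
try:
    F.grid_sample(img5d, grid, mode="nearest", padding_mode="zeros", align_corners=False)
    failed = False
except RuntimeError as err:
    failed = True
    print("5-D input with [N, H, W, 2] grid:", err)

# Folding the extra dimension into the channels yields a 4-D input that
# matches the grid; the original layout is restored afterwards.
img4d = img5d.reshape(n, c * d, h, w)
out = F.grid_sample(img4d, grid, mode="nearest", padding_mode="zeros", align_corners=False)
restored = out.reshape(n, c, d, h, w)
```

If that reading is right, the reshape workaround is not really a detour: it is the shape the underlying sampler needs anyway, and since reshaping a contiguous tensor does not copy data, the overhead should be small.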