Vmap causing TypeError with torchvision rotate

I’m running into an issue trying to vmap over the torchvision rotate function. rotate() requires its angle argument to be an int or float and does not accept single-valued tensors. However, vmap requires all batched inputs to be tensors. Additionally, I can’t write a helper function that calls rotate(im, angle.item()), because .item() calls are not allowed inside vmap.

I feel like there should be a simple solution to this, but I’m not seeing it. Any help would be appreciated.

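import torch
from functorch import vmap
from torchvision.transforms.functional import rotate

# Map rotate over the batch dimension of both the images and the angles.
vrot = vmap(rotate, in_dims=(0, 0), out_dims=0)

b, dim = 10, 64
inp_ims = torch.rand((b, dim, dim))  # batch of b images, each dim x dim
angles = torch.rand(b)               # one angle per image
vrot(inp_ims, angles)                # raises the TypeError below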

This gives: TypeError: Argument angle should be int or float

This might be a known limitation, and I guess the type check on rotate's angle argument would need to be relaxed to accept tensors. Could you create an issue in the torchvision repository so that we can track and fix it?
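In the meantime, one possible workaround is to skip vmap and do the per-sample rotation with plain tensor ops. The sketch below builds one rotation matrix per image and resamples with torch.nn.functional.affine_grid and grid_sample. The helper name rotate_batch is hypothetical (it is not part of torchvision), it assumes square single-channel images of shape (B, H, W) as in the snippet above, and its output is not guaranteed to be pixel-identical to torchvision's rotate.

```python
import torch
import torch.nn.functional as F


def rotate_batch(images, angles_deg, mode="nearest"):
    """Rotate each image in a (B, H, W) batch by its own angle in degrees.

    Hypothetical helper, not a torchvision API. Assumes square images;
    positive angles rotate counter-clockwise. Pixels that fall outside
    the source image are filled with zeros.
    """
    theta = torch.deg2rad(angles_deg)
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros = torch.zeros_like(theta)

    # One 2x3 affine matrix per sample. affine_grid maps output
    # coordinates to input coordinates in the normalized [-1, 1] range.
    mats = torch.stack(
        [
            torch.stack([cos, -sin, zeros], dim=1),
            torch.stack([sin, cos, zeros], dim=1),
        ],
        dim=1,
    )  # shape (B, 2, 3)

    x = images.unsqueeze(1)  # add a channel dim: (B, 1, H, W)
    grid = F.affine_grid(mats, list(x.shape), align_corners=False)
    out = F.grid_sample(
        x, grid, mode=mode, padding_mode="zeros", align_corners=False
    )
    return out.squeeze(1)  # back to (B, H, W)


b, dim = 10, 64
inp_ims = torch.rand(b, dim, dim)
angles = torch.rand(b) * 360  # one angle in degrees per image
out = rotate_batch(inp_ims, angles)
```

Because the angles stay tensors the whole way through, no .item() call is needed and the whole batch is processed in a single grid_sample call. For non-square images the normalized coordinates would stretch the rotation, so the matrices would need an aspect-ratio correction.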

I just opened an issue: "Vmap causing TypeError when applied to Rotate" (Issue #6524 in pytorch/vision on GitHub).
Thanks!