I’m running into an issue trying to vmap over torchvision's rotate function. rotate() requires the angle to be an int or float and does not accept single-element tensors. However, vmap requires all batched inputs to be tensors. I also can't write a helper function that calls rotate(im, angle.item()), because .item() calls are not allowed inside vmap.
I feel like there should be a simple solution to this, but I’m not seeing it. Any help would be appreciated.
```python
import torch
from functorch import vmap
from torchvision.transforms.functional import rotate

vrot = vmap(rotate, in_dims=(0, 0), out_dims=0)

b, dim = 10, 64
inp_ims = torch.rand((b, dim, dim))  # batch of b single-channel dim x dim images
angles = torch.rand(b)               # one angle per image
vrot(inp_ims, angles)
```
This gives: TypeError: Argument angle should be int or float
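One possible workaround, offered as a sketch under stated assumptions rather than a known-good fix, is to skip vmap and torchvision's rotate entirely. You can build one rotation matrix per image directly from the angle tensor with torch.cos/torch.sin, then resample with torch.nn.functional.affine_grid and grid_sample. The angles stay tensors the whole time, so no .item() call is needed and the whole batch is handled in one shot. The helper name rotate_batch is hypothetical. It assumes square single-channel (B, H, W) images (as in the question), angles in degrees, counter-clockwise rotation about the image center, nearest-neighbour interpolation and zero fill. It is not guaranteed to match torchvision's rotate pixel for pixel.

```python
import math

import torch
import torch.nn.functional as F


def rotate_batch(images: torch.Tensor, angles_deg: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: rotate each image by its own angle.

    images:     (B, H, W) float tensor; square, single-channel images (assumed)
    angles_deg: (B,) float tensor, degrees, counter-clockwise (assumed convention)
    """
    theta = angles_deg.to(images.dtype) * (math.pi / 180.0)
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros = torch.zeros_like(theta)

    # affine_grid maps *output* coordinates to *input* sampling coordinates
    # (normalized to [-1, 1], y pointing down). This matrix makes output
    # pixel (x, y) sample the input at (x*cos - y*sin, x*sin + y*cos),
    # which turns the image counter-clockwise on screen about its center.
    mat = torch.stack(
        [
            torch.stack([cos, -sin, zeros], dim=-1),
            torch.stack([sin, cos, zeros], dim=-1),
        ],
        dim=-2,
    )  # (B, 2, 3)

    x = images.unsqueeze(1)  # (B, 1, H, W): add a channel dim for grid_sample
    grid = F.affine_grid(mat, list(x.shape), align_corners=False)
    out = F.grid_sample(
        x, grid, mode="nearest", padding_mode="zeros", align_corners=False
    )
    return out.squeeze(1)  # back to (B, H, W)
```

With the tensors from the question, rotate_batch(inp_ims, angles * 360) would rotate each image by its own angle in [0, 360) degrees. Because the rotation is done in normalized coordinates, non-square images would be distorted for angles other than multiples of 180 degrees; that case would need the aspect ratio folded into the matrix.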