Batched affine transformation

I have a batch of images/observations and I want to apply an affine transformation to each of them as a batch, with the angle, shear, translation, etc. for each sample provided as a tensor.
I could obviously do this with a Python loop, but I'm trying to make this as performant as possible. Here is a minimal example of what I've tried:

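import torch
import torchvision.transforms.functional
from torch.func import vmap  # functorch.vmap in older releases

def affine(img, angle):
    # functional.affine expects Python numbers for angle/translate/scale/shear,
    # so passing a per-sample tensor angle through vmap fails.
    return torchvision.transforms.functional.affine(
        img=img,
        angle=angle,
        translate=(0, 0),
        scale=1.0,
        shear=1.0,
    )

batch_affine = vmap(affine)
batch_size = 10
x = torch.rand(batch_size, 3, 3)  # stand-in for my batch of observations
y = batch_affine(
    x,
    torch.rand((batch_size,), dtype=torch.float),
)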

The example above doesn't work because functional.affine expects Python float arguments, not tensors. Unfortunately, you can't call .item() inside vmap, so I'm not sure what to do.
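For reference, the limitation can be reproduced without torchvision at all. This is a minimal sketch assuming PyTorch 2.x, where vmap is exposed as `torch.func.vmap` (older setups import it from `functorch`); both helper functions are hypothetical. Turning a per-sample tensor into a Python number with `.item()` inside the vmapped function makes vmap raise, while the same arithmetic written with tensor ops batches fine:

```python
import torch
from torch.func import vmap  # functorch.vmap in older releases


def double_via_item(angle):
    # Hypothetical helper: converts the per-sample tensor to a Python float,
    # which is data-dependent and cannot be batched by vmap.
    return torch.tensor(angle.item() * 2.0)


def double_via_tensor_ops(angle):
    # Hypothetical helper: the same computation kept as a tensor op,
    # which vmap can batch.
    return angle * 2.0


angles = torch.rand(10)
print(vmap(double_via_tensor_ops)(angles).shape)  # torch.Size([10])

try:
    vmap(double_via_item)(angles)
except RuntimeError as err:
    print("vmap rejected .item():", err)
```

So the fix can't come from the caller's side: a function that needs a Python float per sample has to be rewritten in tensor ops before vmap can batch it.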

Hi @Ryan_Keathley1,

Does the .item() error occur within torchvision.transforms.functional.affine? If so, open an issue on their GitHub: Issues · pytorch/functorch · GitHub

No, it’s in the vmap call. It gives a specific warning that it’s not supported.

If that function isn't supported by vmap, you'll need to open an issue (link above) and they'll add it to their list of operations to vectorize!
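If waiting on vmap support isn't an option, one alternative that avoids both vmap and a Python loop is to build one 2×3 affine matrix per sample and pass the whole batch to `torch.nn.functional.affine_grid` and `torch.nn.functional.grid_sample`, which accept a batch of matrices directly. This is a swapped-in technique, not what torchvision's `affine` does internally, and the conventions below are assumptions of this sketch: positive angles rotate counter-clockwise as displayed, translation is in normalized units (2.0 spans the image), shear is horizontal only, and images are square. It is not guaranteed to match torchvision's output pixel for pixel, so compare against torchvision's `affine` on a few samples if parity matters. `batched_affine` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def batched_affine(imgs, angle_deg, translate, scale, shear_deg):
    """Apply a different affine transform to every image in a batch, without a loop.

    imgs:      (N, C, H, W) float tensor, assumed square (H == W)
    angle_deg: (N,) rotation in degrees, positive = counter-clockwise as displayed
               (same direction as torch.rot90(k=1, dims=(-2, -1)))
    translate: (N, 2) shift (tx, ty) in normalized units; 2.0 spans the full image
    scale:     (N,) isotropic scale factor
    shear_deg: (N,) horizontal shear angle in degrees
    """
    rad = torch.deg2rad(angle_deg)
    cos, sin = torch.cos(rad), torch.sin(rad)
    # Rotation in image coordinates (x to the right, y pointing down), shape (N, 2, 2).
    rot = torch.stack(
        [torch.stack([cos, sin], dim=-1), torch.stack([-sin, cos], dim=-1)], dim=-2
    )
    # Horizontal shear matrix [[1, tan(phi)], [0, 1]], shape (N, 2, 2).
    shear = torch.zeros_like(rot)
    shear[:, 0, 0] = 1.0
    shear[:, 1, 1] = 1.0
    shear[:, 0, 1] = torch.tan(torch.deg2rad(shear_deg))
    # Forward linear map: shear, then rotate, then scale.
    fwd = scale[:, None, None] * (rot @ shear)
    # affine_grid expects the inverse map (output coords -> input coords):
    # p_in = fwd^-1 (p_out - t) = fwd^-1 p_out - fwd^-1 t
    inv = torch.linalg.inv(fwd)
    offset = -(inv @ translate.unsqueeze(-1))  # (N, 2, 1)
    theta = torch.cat([inv, offset], dim=-1)  # (N, 2, 3)
    grid = F.affine_grid(theta, list(imgs.shape), align_corners=False)
    return F.grid_sample(
        imgs, grid, mode="bilinear", padding_mode="zeros", align_corners=False
    )
```

Every parameter stays a tensor, so one call handles the whole batch with per-sample values. For non-square images the normalized x and y units cover different numbers of pixels, so rotations would need an aspect-ratio correction on top of this.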