I have a batch of images (observations) and I want to apply a different affine transformation to each of them in a single batched call, with the angle, shear, translation, etc. provided as tensors.
Obviously I could do this with a Python loop, but I'm trying to make this as fast as possible. Here is a minimal example of what I've tried:
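```python
import torch
import torchvision
from torch.func import vmap  # functorch.vmap in older releases


def affine(img, angle):
    # Called once per sample by vmap; `angle` arrives here as a 0-d tensor,
    # not a Python float.
    return torchvision.transforms.functional.affine(
        img=img,
        angle=angle,
        translate=(0, 0),
        scale=1,
        shear=1,
    )


batch_affine = vmap(affine)

batch_size = 10
x = torch.rand(batch_size, 3, 3)  # stand-in for my batch of observations
y = batch_affine(
    x,
    torch.rand((batch_size,), dtype=torch.float),
)
```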
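The example above doesn't work, because `functional.affine` expects plain Python numbers (int or float) for its arguments, while vmap hands it `angle` as a tensor. Unfortunately, you can't call `.item()` inside vmap, so I'm not sure what to do.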
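One possible workaround, sketched here under several assumptions, is to skip vmap and `torchvision.transforms.functional.affine` entirely: build one affine matrix per sample as a tensor and let `torch.nn.functional.affine_grid` plus `torch.nn.functional.grid_sample` warp the whole batch in one call. The helper names `batch_affine` and `_mat3` are hypothetical. Assumptions to be aware of: translation is in normalized coordinates ([-1, 1] spans the image) rather than pixels; shear is a simple x-axis shear; rotation is done in normalized coordinates, so it is only distortion-free for square images; and the composition order (translate, rotate, shear, scale) and rotation direction are my own choices that may not match torchvision's conventions exactly.

```python
import torch
import torch.nn.functional as F


def _mat3(rows):
    # Hypothetical helper: 3 lists of 3 tensors, each of shape (N,) -> (N, 3, 3).
    return torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2)


def batch_affine(imgs, angle_deg, translate, scale, shear_deg):
    """Warp each image in imgs (N, C, H, W) with its own affine parameters.

    angle_deg: (N,) rotation in degrees
    translate: (N, 2) shift in normalized coordinates (assumption, not pixels)
    scale:     (N,) isotropic scale factor
    shear_deg: (N,) x-axis shear in degrees (simplified shear, an assumption)
    """
    a = torch.deg2rad(angle_deg)
    sh = torch.tan(torch.deg2rad(shear_deg))
    zero, one = torch.zeros_like(a), torch.ones_like(a)

    rot = _mat3([[torch.cos(a), -torch.sin(a), zero],
                 [torch.sin(a), torch.cos(a), zero],
                 [zero, zero, one]])
    shr = _mat3([[one, sh, zero], [zero, one, zero], [zero, zero, one]])
    scl = _mat3([[scale, zero, zero], [zero, scale, zero], [zero, zero, one]])
    trn = _mat3([[one, zero, translate[:, 0]],
                 [zero, one, translate[:, 1]],
                 [zero, zero, one]])

    # Forward map (input -> output coordinates), composed per sample.
    forward = trn @ rot @ shr @ scl
    # affine_grid wants the inverse map (output -> input) as an (N, 2, 3) tensor.
    theta = torch.linalg.inv(forward)[:, :2, :]

    grid = F.affine_grid(theta, imgs.shape, align_corners=False)
    return F.grid_sample(imgs, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)


# Usage: per-sample random rotations, no Python loop and no vmap.
batch_size = 10
imgs = torch.rand(batch_size, 1, 32, 32)
out = batch_affine(
    imgs,
    angle_deg=torch.rand(batch_size) * 360 - 180,
    translate=torch.zeros(batch_size, 2),
    scale=torch.ones(batch_size),
    shear_deg=torch.zeros(batch_size),
)
```

Because every step is an ordinary batched tensor op, the parameters can live on the GPU and gradients flow back to them if they require grad. This sketch expects a channel dimension, so a batch shaped like `x` in the question (N, H, W) would need `x.unsqueeze(1)` first.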