F.affine_grid in batch mode is slow for high-resolution images

I’m using F.affine_grid to apply transformation matrices of size Bx2x3 to a batch of images of size Bx3x1024x1024. I noticed that the following code, which builds the grid for the whole batch in a single call,

import time
import torch
import torch.nn.functional as F

batch_size = 128
resolution = 1024
M = torch.eye(3, device='cuda').unsqueeze(0).expand(batch_size, -1, -1)[:, :2, :]
img = torch.zeros(batch_size, 3, resolution, resolution, device='cuda')

total_time = 0
total_iterations = 20

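for _ in range(total_iterations):
    start = time.time()
    # Build the sampling grid for the whole batch in a single call.
    grid = F.affine_grid(M, img.shape, align_corners=False)
    transformed_img = F.grid_sample(img, grid, align_corners=False)
    # Wait for the GPU to finish so the timer covers the kernels, not just their launch.
    torch.cuda.synchronize()
    total_time += time.time() - start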

print(total_time/total_iterations)  # 0.2470783233642578 seconds

is significantly slower than computing the grid for each batch sample separately and concatenating the results (about 0.247 s versus 0.072 s per iteration):

import time
import torch
import torch.nn.functional as F

batch_size = 128
resolution = 1024
M = torch.eye(3, device='cuda').unsqueeze(0).expand(batch_size, -1, -1)[:, :2, :]
img = torch.zeros(batch_size, 3, resolution, resolution, device='cuda')

total_time = 0
total_iterations = 20

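for _ in range(total_iterations):
    start = time.time()
    # Build one grid per sample (batch size 1 each), then concatenate along the batch dimension.
    grid = torch.cat([F.affine_grid(M[i:i+1], [1, *img.shape[1:]], align_corners=False) for i in range(img.shape[0])], dim=0)
    transformed_img = F.grid_sample(img, grid, align_corners=False)
    # Wait for the GPU to finish so the timer covers the kernels, not just their launch.
    torch.cuda.synchronize()
    total_time += time.time() - start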

print(total_time/total_iterations)  # 0.07158441543579101 seconds

I believe this is caused by the bmm operation used to create the grid, which multiplies a Bx(1024*1024)x3 matrix by a Bx3x2 matrix. Even so, having to create the grid for each batch sample separately seems counterintuitive.
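To make the suspected bottleneck concrete, here is a rough sketch of the same grid computation written with broadcasting instead of bmm. The function name affine_grid_broadcast is made up for illustration, it only covers the 2D align_corners=False case, and whether it is actually faster on a GPU is untested here and would need to be measured.

```python
import torch
import torch.nn.functional as F


def affine_grid_broadcast(theta, size):
    """Hypothetical helper: 2D affine grid for align_corners=False.

    theta: (B, 2, 3) affine matrices; size: (B, C, H, W).
    Returns a (B, H, W, 2) grid, computed without bmm.
    """
    B, _, H, W = size
    # Pixel-centre coordinates normalised to [-1, 1] (align_corners=False).
    xs = (2 * torch.arange(W, device=theta.device, dtype=theta.dtype) + 1) / W - 1
    ys = (2 * torch.arange(H, device=theta.device, dtype=theta.dtype) + 1) / H - 1
    x = xs.view(1, 1, W, 1)
    y = ys.view(1, H, 1, 1)
    # Columns of theta, each reshaped to (B, 1, 1, 2) so they broadcast over H and W.
    a = theta[:, None, None, :, 0]
    b = theta[:, None, None, :, 1]
    c = theta[:, None, None, :, 2]
    # out[n, h, w, k] = theta[n, k, 0] * x_w + theta[n, k, 1] * y_h + theta[n, k, 2]
    return a * x + b * y + c


if __name__ == "__main__":
    # Small CPU comparison against F.affine_grid.
    theta = torch.randn(4, 2, 3)
    size = [4, 3, 5, 7]
    reference = F.affine_grid(theta, size, align_corners=False)
    print((affine_grid_broadcast(theta, size) - reference).abs().max())
```

The result can be passed to F.grid_sample in place of the output of F.affine_grid, for example grid = affine_grid_broadcast(M, img.shape).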

Since my application doesn’t require this operation to be differentiable, I’m wondering if there’s an alternative way of applying a transformation matrix to an image in batch mode that is fast enough for high-resolution images.