I’m using `F.affine_grid` to apply transformation matrices of size `Bx2x3` to a batch of images of size `Bx3x1024x1024`. I noticed that the following code
```python
import time
import torch
import torch.nn.functional as F

batch_size = 128
resolution = 1024
M = torch.eye(3, device='cuda').unsqueeze(0).expand(batch_size, -1, -1)[:, :2, :]
img = torch.zeros(batch_size, 3, resolution, resolution, device='cuda')

total_time = 0
total_iterations = 20
for _ in range(total_iterations):
    start = time.time()
    # Build the grid for the whole batch at once
    grid = F.affine_grid(M, img.shape, align_corners=False)
    transformed_img = F.grid_sample(img, grid, align_corners=False)
    torch.cuda.synchronize()
    total_time += time.time() - start
print(total_time / total_iterations)  # 0.2470783233642578 seconds
```
is significantly slower (about 0.25 s vs. 0.07 s per iteration) than computing the grid for each batch sample separately:
```python
import time
import torch
import torch.nn.functional as F

batch_size = 128
resolution = 1024
M = torch.eye(3, device='cuda').unsqueeze(0).expand(batch_size, -1, -1)[:, :2, :]
img = torch.zeros(batch_size, 3, resolution, resolution, device='cuda')

total_time = 0
total_iterations = 20
for _ in range(total_iterations):
    start = time.time()
    # Build the grid one sample at a time, then concatenate along the batch dim
    grid = torch.cat(
        [F.affine_grid(M[i:i+1], [1, *img.shape[1:]], align_corners=False)
         for i in range(img.shape[0])],
        dim=0,
    )
    transformed_img = F.grid_sample(img, grid, align_corners=False)
    torch.cuda.synchronize()
    total_time += time.time() - start
print(total_time / total_iterations)  # 0.07158441543579101 seconds
```
I believe this is caused by the `bmm` used to create the grid, which multiplies a `Bx(1024*1024)x3` matrix by a `Bx3x2` matrix. Still, having to create the grid for each batch sample separately to get better speed seems counterintuitive.
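To make those shapes concrete, here is a sketch of what I suspect the batched path computes. It is based on my reading of PyTorch's C++ implementation, so treat it as an approximation that may differ between versions; `affine_grid_via_bmm` is a hypothetical name, not a PyTorch function. The point is that the `(B, H, W, 3)` base grid stores identical coordinates for every sample: with `B = 128` and `1024x1024` it takes 128 × 1024 × 1024 × 3 × 4 bytes = 1.5 GiB of float32 before the `bmm` even runs.

```python
import torch


def affine_grid_via_bmm(theta, size, align_corners=False):
    """Hypothetical re-implementation of F.affine_grid for 4-D sizes.

    Assumes affine_grid materialises a (B, H, W, 3) base grid and multiplies
    it by theta^T with bmm; this is an assumption, not a documented API.
    theta: (B, 2, 3) tensor, size: (B, C, H, W). Returns a (B, H, W, 2) grid.
    """
    B, _, H, W = size
    opts = dict(dtype=theta.dtype, device=theta.device)

    def coords(n):
        # Normalised coordinates in [-1, 1]; with align_corners=False they are
        # pixel centres, i.e. linspace(-1, 1, n) scaled by (n - 1) / n.
        if n <= 1:
            return torch.zeros(n, **opts)
        r = torch.linspace(-1, 1, n, **opts)
        return r if align_corners else r * (n - 1) / n

    # The same (H, W, 3) coordinates are written into all B slots.
    base = torch.empty(B, H, W, 3, **opts)
    base[..., 0] = coords(W)                # x varies along the width
    base[..., 1] = coords(H).unsqueeze(-1)  # y varies along the height
    base[..., 2] = 1                        # homogeneous coordinate
    # (B, H*W, 3) @ (B, 3, 2) -> (B, H*W, 2)
    grid = torch.bmm(base.view(B, H * W, 3), theta.transpose(1, 2))
    return grid.view(B, H, W, 2)
```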
Since my application doesn’t require this operation to be differentiable, I’m wondering if there’s an alternative way of applying a transformation matrix to an image in batch mode that is fast enough for high-resolution images.
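One possible direction, sketched here without benchmarks: because the base grid is separable (x depends only on the column, y only on the row), the affine map can be applied with broadcasting, so no `B`-fold copy of the base grid and no `bmm` are needed. `fast_affine_grid` is a hypothetical helper, only handles the 4-D case, and is meant to match `F.affine_grid` up to floating-point rounding.

```python
import torch
import torch.nn.functional as F


def fast_affine_grid(theta, size, align_corners=False):
    """Hypothetical broadcasting replacement for F.affine_grid (4-D only).

    theta: (B, 2, 3) tensor, size: (B, C, H, W). Returns a (B, H, W, 2) grid
    intended to match F.affine_grid(theta, size, align_corners) up to rounding.
    """
    B, _, H, W = size
    opts = dict(dtype=theta.dtype, device=theta.device)

    def coords(n):
        # Intended to match affine_grid's normalised coordinates in [-1, 1].
        if n <= 1:
            return torch.zeros(n, **opts)
        r = torch.linspace(-1, 1, n, **opts)
        return r if align_corners else r * (n - 1) / n

    x = coords(W).view(1, 1, W)  # varies along the width
    y = coords(H).view(1, H, 1)  # varies along the height
    # t[:, k] is the k-th entry (row-major) of each 2x3 matrix, shape (B, 1, 1)
    t = theta.reshape(B, 6, 1, 1)

    # The small (B, 1, W) and (B, H, 1) terms are combined last, so each
    # component allocates only one full (B, H, W) tensor.
    gx = (t[:, 0] * x + t[:, 2]) + t[:, 1] * y
    gy = (t[:, 3] * x + t[:, 5]) + t[:, 4] * y
    return torch.stack((gx, gy), dim=-1)


if __name__ == "__main__":
    # Identity transform on a small CPU batch, mirroring the setup above.
    theta = torch.eye(3).unsqueeze(0).expand(2, -1, -1)[:, :2, :]
    img = torch.rand(2, 3, 8, 8)
    grid = fast_affine_grid(theta, img.shape, align_corners=False)
    out = F.grid_sample(img, grid, align_corners=False)
    print(out.shape)
```

In the benchmark above, this would replace the grid construction with `grid = fast_affine_grid(M, img.shape)` while the `F.grid_sample` call stays unchanged. Whether it actually beats the per-sample loop on a given GPU would need to be measured with the same timing code.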