Batched matrix multiplication speed

I am multiplying two relatively large matrices with torch.bmm, and changing the batch size from 1 to 2 increases the execution time about 20-fold (from roughly 1.7 ms to 35 ms). Why does this happen, and can anything be done about it? Here are the two examples:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
matrix = torch.randn(1, 2, 10000000, device='cuda:0')
start_event.record()
torch.bmm(matrix, matrix.transpose(1, 2))
end_event.record()
torch.cuda.synchronize()
print(f"Time elapsed: {start_event.elapsed_time(end_event)}")

# Time elapsed: 1.6721919775009155

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
matrix = torch.randn(2, 2, 10000000, device='cuda:0')
start_event.record()
torch.bmm(matrix, matrix.transpose(1, 2))
end_event.record()
torch.cuda.synchronize()
print(f"Time elapsed: {start_event.elapsed_time(end_event)}")

# Time elapsed: 34.95014572143555

PyTorch version: 1.13.1
GPU: NVIDIA GeForce RTX 3090 24GB
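
Each measurement above times a single call with no warm-up, so the numbers may include one-time setup costs. To help rule that out, here is a minimal sketch of a steadier benchmark that runs a few warm-up iterations, averages several timed runs, and compares torch.bmm against a per-batch loop of 2D matmuls. The helper names (time_ms, bmm_via_loop, compare) and the warm-up and iteration counts are my own assumptions for illustration, not part of PyTorch, and it falls back to the CPU when no GPU is present.

```python
import time
import torch

def time_ms(fn, warmup=3, iters=10):
    """Average wall-clock time of fn() in milliseconds, after warm-up runs."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()                      # warm-up: absorbs one-time setup costs
    if use_cuda:
        torch.cuda.synchronize()  # make sure the warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to complete
    return (time.perf_counter() - start) * 1000.0 / iters

def bmm_via_loop(m):
    """Same result as torch.bmm(m, m.transpose(1, 2)), one 2D matmul per batch."""
    return torch.stack([m[i] @ m[i].t() for i in range(m.shape[0])])

def compare(batch, rows=2, cols=10_000_000, device=None):
    """Return (bmm_ms, loop_ms) for a random (batch, rows, cols) tensor."""
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    m = torch.randn(batch, rows, cols, device=device)
    mt = m.transpose(1, 2)
    bmm_ms = time_ms(lambda: torch.bmm(m, mt))
    loop_ms = time_ms(lambda: bmm_via_loop(m))
    return bmm_ms, loop_ms
```

Calling compare(1) and compare(2) on the GPU should show whether the gap persists after warm-up, and whether the per-batch loop avoids it for this shape.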
