I’m trying to speed up some computations by converting them into a single batched matrix operation. I expected this to be faster because it would allow the computations to run in parallel, whereas a for loop performs each computation sequentially, but that was not the case when I tested it. So I’m here to get a better understanding of how this works: why is the batched version so much slower than the loop?
import timeit
import torch
def test_thing():
    """Benchmark a broadcast (batched) matmul against a per-module Python loop.

    Three equivalent ways of applying ``num_modules`` affine maps
    (``x @ W.T + b``) to the same input tensor are compared:

    * ``run_batched``     -- one broadcast matmul; the broadcastable views
      are rebuilt on every call.
    * ``run_batched2``    -- the same broadcast matmul using views that are
      prepared once, outside the timed region.
    * ``run_non_batched`` -- a Python loop issuing one matmul per module,
      followed by ``torch.stack``.

    The total wall-clock time of ``num_iterations`` calls to each variant is
    printed.

    Raises:
        AssertionError: if the variants disagree numerically, or if either
            batched variant is not faster than the loop.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_iterations = 1_000
    num_modules = 12
    num_inputs = 13
    batch_size = 16
    embed_dims = 512

    # Initialize inputs and weights/biases
    x = torch.rand([num_inputs, batch_size, embed_dims], device=device)
    weight = torch.rand([num_modules, embed_dims, embed_dims], device=device)
    bias = torch.rand([num_modules, embed_dims], device=device)

    # Pre-built broadcastable views used by run_batched2:
    #   x2      -> [1, num_inputs, batch_size, embed_dims]
    #   weight2 -> [num_modules, 1, embed_dims, embed_dims] (transposed)
    #   bias2   -> [num_modules, 1, 1, embed_dims]
    x2 = x.unsqueeze(0)
    weight2 = weight.unsqueeze(1).transpose(-2, -1)
    bias2 = bias.unsqueeze(1).unsqueeze(2)

    def sync():
        # CUDA work is queued asynchronously, so the timed functions must
        # wait for the device before returning. On the CPU fallback there is
        # nothing to wait for, and torch.cuda.synchronize() fails when no
        # CUDA device is available, so it is skipped there.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    # Functions to evaluate
    def run_batched():
        output = (
            x.unsqueeze(0) @ weight.unsqueeze(1).transpose(-2, -1)
            + bias.unsqueeze(1).unsqueeze(2)
        )
        sync()
        return output

    def run_batched2():
        output = x2 @ weight2 + bias2
        sync()
        return output

    def run_non_batched():
        output = torch.stack([
            x @ w.transpose(-2, -1) + b
            for w, b in zip(weight, bias)
        ])
        # Synchronize after the stack so that all work issued by this
        # function has finished before it returns.
        sync()
        return output

    # Ensure that they are all computing the same thing. These calls also
    # run each variant once before any timing starts.
    reference = run_non_batched()
    assert torch.allclose(run_batched(), reference)
    assert torch.allclose(run_batched2(), reference)

    # Measure run time
    total_time_batched = timeit.Timer(run_batched).timeit(number=num_iterations)
    total_time_batched2 = timeit.Timer(run_batched2).timeit(number=num_iterations)
    total_time_nonbatched = timeit.Timer(run_non_batched).timeit(number=num_iterations)

    print(f"Batch: {total_time_batched}")
    print(f"Batch2: {total_time_batched2}")
    print(f"NonBatch: {total_time_nonbatched}")

    # These encode the hypothesis under test (batched should be faster).
    # They fail whenever the measurement contradicts it, as it does in the
    # console output posted with this script.
    assert total_time_batched2 < total_time_nonbatched
    assert total_time_batched < total_time_nonbatched
Console output:
Batch: 3.460433868924156
Batch2: 3.4635026860050857
NonBatch: 0.6445531530771405