Why are batched operations slower than a for loop?

Hi Howard!

Good question – I don’t know.

I can reproduce your result, including on a recent nightly build.

It’s interesting to note that a non-loop einsum() version is competitive
with your “non-batched” loop version, and, at least on my system, turns
out to be somewhat faster.

(I see slower, but similar relative timings using just the cpu, with einsum()
again being the fastest version.)
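
(For reference, here is the index bookkeeping behind the einsum() version.
This is a minimal, self-contained sketch with made-up small shapes, checked
against the loop version:)

import torch

# x:      (i, j, k) = (num_inputs, batch_size, embed_dims)
# weight: (l, m, k) = (num_modules, embed_dims, embed_dims)
# result: (l, i, j, m), that is, one (i, j, m) output per module
x = torch.rand(3, 4, 5)
weight = torch.rand(2, 5, 5)
out_einsum = torch.einsum('ijk,lmk->lijm', x, weight)
out_loop = torch.stack([x @ w.transpose(-2, -1) for w in weight])
print(torch.allclose(out_einsum, out_loop))   # True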

Running a tweaked version of your test code, with einsum() added as
“Batch3,” I get:

1.10.0
10.2
GeForce GTX 1050 Ti
check run_batched: True
check run_batched2: True
check run_batched3: True
Batch: 6.828574229999504
Batch2: 11.512633088001166
Batch3: 0.9353041959984694
NonBatch: 1.3501966660005564

Here is my version of your test script:

import timeit
import torch

def test_thing():
    print(torch.__version__)
    if torch.cuda.is_available():
        print(torch.version.cuda)
        print(torch.cuda.get_device_name())
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    
    num_iterations = 1_000
    
    num_modules = 12
    num_inputs = 13
    batch_size = 16
    embed_dims = 512
    
    # Initialize inputs and weights/biases
    x = torch.rand([num_inputs, batch_size, embed_dims], device=device)
    weight = torch.rand([num_modules, embed_dims, embed_dims], device=device)
    bias = torch.rand([num_modules, embed_dims], device=device)
    
    # Precompute the broadcast-ready views used by run_batched2() and run_batched3()
    x2 = x.unsqueeze(0)
    weight2 = weight.unsqueeze(1).transpose(-2,-1)
    bias2 = bias.unsqueeze(1).unsqueeze(2)
    
    # Functions to evaluate
    def run_batched():
        output = x.unsqueeze(0) @ weight.unsqueeze(1).transpose(-2,-1) + bias.unsqueeze(1).unsqueeze(2)
        torch.cuda.synchronize()   # synchronize here, too, so all versions are timed the same way
        return output
    
    def run_batched2():
        output = x2 @ weight2 + bias2
        torch.cuda.synchronize()
        return output
    
    def run_batched3():
        output = torch.einsum('ijk,lmk->lijm', x, weight) + bias2
        torch.cuda.synchronize()
        return output
    
    def run_non_batched():
        output = [
            x @ w.transpose(-2,-1) + b
            for w, b in zip(weight, bias)
        ]
        torch.cuda.synchronize()
        return torch.stack(output)
    
    # Ensure that they are all computing the same thing
    print('check run_batched:', torch.allclose(run_batched(), run_non_batched()))
    print('check run_batched2:', torch.allclose(run_batched2(), run_non_batched()))
    print('check run_batched3:', torch.allclose(run_batched3(), run_non_batched()))
    
    # Measure run time
    total_time_batched = timeit.Timer(run_batched).timeit(number=num_iterations)
    total_time_batched2 = timeit.Timer(run_batched2).timeit(number=num_iterations)
    total_time_batched3 = timeit.Timer(run_batched3).timeit(number=num_iterations)
    total_time_nonbatched = timeit.Timer(run_non_batched).timeit(number=num_iterations)
    
    print(f"Batch: {total_time_batched}")
    print(f"Batch2: {total_time_batched2}")
    print(f"Batch3: {total_time_batched3}")
    print(f"NonBatch: {total_time_nonbatched}")

test_thing()
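
(An aside on methodology: if you would rather not put torch.cuda.synchronize()
calls inside the functions being timed, torch.utils.benchmark handles cuda
synchronization and warm-up for you. A minimal sketch, reusing run_batched()
from the script above:)

import torch.utils.benchmark as benchmark

timer = benchmark.Timer(
    stmt='run_batched()',
    globals={'run_batched': run_batched},
)
print(timer.timeit(number=1_000))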

For completeness, here are my nightly-build results:

1.12.0.dev20220410
11.3
GeForce GTX 1050 Ti
check run_batched: True
check run_batched2: True
check run_batched3: True
Batch: 6.188123900999926
Batch2: 10.426687854000193
Batch3: 0.92185746800169
NonBatch: 1.3100559570011683

Best.

K. Frank