# Why are batched operations slower than a for loop?

I’m trying to speed up some computations by converting them to a single batched matrix operation. I thought this would be faster because it would allow the computations to run in parallel, whereas the for loop performs each computation sequentially, but that was not the case when I tested it. So I’m here to get a better understanding of how this works. Why is the batched version so much slower than the loop?

```python
import timeit
import torch

def test_thing():
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    num_iterations = 1_000

    num_modules = 12
    num_inputs = 13
    batch_size = 16
    embed_dims = 512

    # Initialize inputs and weights/biases
    x = torch.rand([num_inputs, batch_size, embed_dims], device=device)
    weight = torch.rand([num_modules, embed_dims, embed_dims], device=device)
    bias = torch.rand([num_modules, embed_dims], device=device)

    # Pre-broadcast views: x2 is [1, I, B, E], weight2 is [M, 1, E, E],
    # bias2 is [M, 1, 1, E] (M=num_modules, I=num_inputs, B=batch_size, E=embed_dims)
    x2 = x.unsqueeze(0)
    weight2 = weight.unsqueeze(1).transpose(-2, -1)
    bias2 = bias.unsqueeze(1).unsqueeze(2)

    # Functions to evaluate
    def run_batched():
        output = x.unsqueeze(0) @ weight.unsqueeze(1).transpose(-2, -1) + bias.unsqueeze(1).unsqueeze(2)
        torch.cuda.synchronize()
        return output

    def run_batched2():
        output = x2 @ weight2 + bias2
        torch.cuda.synchronize()
        return output

    def run_non_batched():
        output = [
            x @ w.transpose(-2, -1) + b
            for w, b in zip(weight, bias)
        ]
        torch.cuda.synchronize()
        # Stack the per-module results so they can be compared against the batched output
        return torch.stack(output)

    # Ensure that they are all computing the same thing
    assert torch.allclose(run_batched(), run_non_batched())
    assert torch.allclose(run_batched2(), run_non_batched())

    # Measure run time
    total_time_batched = timeit.Timer(run_batched).timeit(number=num_iterations)
    total_time_batched2 = timeit.Timer(run_batched2).timeit(number=num_iterations)
    total_time_nonbatched = timeit.Timer(run_non_batched).timeit(number=num_iterations)

    print(f"Batch: {total_time_batched}")
    print(f"Batch2: {total_time_batched2}")
    print(f"NonBatch: {total_time_nonbatched}")

    # I expected the batched versions to win; these assertions actually fail
    assert total_time_batched2 < total_time_nonbatched
    assert total_time_batched < total_time_nonbatched

test_thing()
```

Console output:

```
Batch: 3.460433868924156
Batch2: 3.4635026860050857
NonBatch: 0.6445531530771405
```

Hi Howard!

Good question – I don’t know.

I can reproduce your result, including on a recent nightly build.

It’s interesting to note that a non-loop `einsum()` version is competitive
with your “non-batched” loop version, and, at least on my system, turns
out to be somewhat faster.

(I see slower, but similar, relative timings using just the CPU, with `einsum()`
again being the fastest version.)
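
For reference, here is the core of that `einsum()` formulation pulled out on
its own (a minimal sketch using the same shapes as your script; the letters
just name the dimensions being contracted and broadcast):

```python
import torch

# Same shapes as the test script:
# x:      [num_inputs (i), batch_size (j), embed_dims (k)]
# weight: [num_modules (l), embed_dims_out (m), embed_dims_in (k)]
x = torch.rand(13, 16, 512)
weight = torch.rand(12, 512, 512)
bias = torch.rand(12, 512)

# Contract over the shared embed dim k; the bias broadcasts over the
# num_inputs and batch dims. Output shape: [l, i, j, m].
out = torch.einsum('ijk,lmk -> lijm', x, weight) + bias.unsqueeze(1).unsqueeze(2)
print(out.shape)  # torch.Size([12, 13, 16, 512])
```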

Running a tweaked version of your test code, with `einsum()` added as
“Batch3,” I get:

```
1.10.0
10.2
GeForce GTX 1050 Ti
check run_batched: True
check run_batched2: True
check run_batched3: True
Batch: 6.828574229999504
Batch2: 11.512633088001166
Batch3: 0.9353041959984694
NonBatch: 1.3501966660005564
```

Here is my version of your test script:

```python
import timeit
import torch

def test_thing():
    print(torch.__version__)
    if torch.cuda.is_available():
        print(torch.version.cuda)
        print(torch.cuda.get_device_name())
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    num_iterations = 1_000

    num_modules = 12
    num_inputs = 13
    batch_size = 16
    embed_dims = 512

    # Initialize inputs and weights/biases
    x = torch.rand([num_inputs, batch_size, embed_dims], device=device)
    weight = torch.rand([num_modules, embed_dims, embed_dims], device=device)
    bias = torch.rand([num_modules, embed_dims], device=device)

    x2 = x.unsqueeze(0)
    weight2 = weight.unsqueeze(1).transpose(-2, -1)
    bias2 = bias.unsqueeze(1).unsqueeze(2)

    # Functions to evaluate
    def run_batched():
        output = x.unsqueeze(0) @ weight.unsqueeze(1).transpose(-2, -1) + bias.unsqueeze(1).unsqueeze(2)
        torch.cuda.synchronize()
        return output

    def run_batched2():
        output = x2 @ weight2 + bias2
        torch.cuda.synchronize()
        return output

    def run_batched3():
        output = torch.einsum('ijk,lmk -> lijm', x, weight) + bias2
        torch.cuda.synchronize()
        return output

    def run_non_batched():
        output = [
            x @ w.transpose(-2, -1) + b
            for w, b in zip(weight, bias)
        ]
        torch.cuda.synchronize()
        # Stack the per-module results so they can be compared against the batched output
        return torch.stack(output)

    # Ensure that they are all computing the same thing
    print('check run_batched:', torch.allclose(run_batched(), run_non_batched()))
    print('check run_batched2:', torch.allclose(run_batched2(), run_non_batched()))
    print('check run_batched3:', torch.allclose(run_batched3(), run_non_batched()))

    # Measure run time
    total_time_batched = timeit.Timer(run_batched).timeit(number=num_iterations)
    total_time_batched2 = timeit.Timer(run_batched2).timeit(number=num_iterations)
    total_time_batched3 = timeit.Timer(run_batched3).timeit(number=num_iterations)
    total_time_nonbatched = timeit.Timer(run_non_batched).timeit(number=num_iterations)

    print(f"Batch: {total_time_batched}")
    print(f"Batch2: {total_time_batched2}")
    print(f"Batch3: {total_time_batched3}")
    print(f"NonBatch: {total_time_nonbatched}")

test_thing()
```

For completeness, here are my nightly-build results:

```
1.12.0.dev20220410
11.3
GeForce GTX 1050 Ti
check run_batched: True
check run_batched2: True
check run_batched3: True
Batch: 6.188123900999926
Batch2: 10.426687854000193
Batch3: 0.92185746800169
NonBatch: 1.3100559570011683
```

Best.

K. Frank

Thanks KFrank, that’s good to know. I get a significant speed-up using `einsum()` on my machine as well.
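
As an aside, CUDA events are another way to time GPU work like this, instead of `timeit` plus a trailing `synchronize()`. Here is a minimal sketch (the helper name `time_cuda` is just illustrative, `fn` stands for whichever variant is being timed, and it assumes a CUDA device is available):

```python
import torch

def time_cuda(fn, num_iterations=1_000, warmup=10):
    # Warm up so one-time setup cost doesn't pollute the measurement.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_iterations):
        fn()
    end.record()
    # Wait for all queued work so both events have completed timestamps.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time() reports milliseconds
```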