Faster Alternative to a Loop

Hi,

My model has a bottleneck in this loop:

a = torch.zeros(B, N, D1, D2)
for t in range(N):
    Bt = b[:,:t+1,:]              # B,T,D1
    Ct = c[:,:t+1,:]              # B,T,D2
    a[:,t,:,:] = Bt.mT @ Ct       # B,D1,T @ B,T,D2 -> B,D1,D2

It takes two 3D tensors, b and c, of shapes (B, N, D1) and (B, N, D2). For each step t it transposes rows 0:t+1 of b and multiplies them with rows 0:t+1 of c, producing a 4D output a of shape (B, N, D1, D2). But the loop gets really slow (especially in back-propagation) as N gets large. Is there a faster alternative? Some way to do this all with matrix multiplication?

Thanks in advance for your help!!!

FWIW, I eventually reshape a to (B, N, D1*D2), so solutions in that shape work too!

Did you try applying torch.compile to your method to check if it would help?

You could try this:

import torch

B = 2
N = 3
D1 = 4
D2 = 5

a = torch.zeros(B, N, D1, D2)
b = torch.rand(B, N, D1)
c = torch.rand(B, N, D2)

def loop(a, b, c):  # old loop definition you provided
    N = a.size(1)
    for t in range(N):
        Bt = b[:, :t+1, :]
        Ct = c[:, :t+1, :]
        a[:, t, :, :] = Bt.mT @ Ct
    return a

def parallel(b, c):  # parallelized version
    N = b.size(1)
    # Expand to (B, N, N, D): dim 1 indexes t, dim 2 indexes s.
    # The permute/triu/permute round trip zeroes out every entry with s > t.
    Bt = b.unsqueeze(1).repeat(1, N, 1, 1).permute(0, 3, 2, 1).triu().permute(0, 3, 2, 1)
    Ct = c.unsqueeze(1).repeat(1, N, 1, 1).permute(0, 3, 2, 1).triu().permute(0, 3, 2, 1)
    # One batched matmul contracts over s, so a[:, t] == b[:, :t+1].mT @ c[:, :t+1]
    a = Bt.mT @ Ct
    return a

output_loop = loop(a,b,c)
output_parallel = parallel(b,c)

print(torch.allclose(output_loop, output_parallel))
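A side note (a sketch, not from the thread): the repeat-based version materializes an O(B·N²·D) intermediate. Since a[:, t] is just the running sum of the outer products of b[:, s] and c[:, s] for s ≤ t, a cumulative sum over the time dimension gives the same result with only an O(B·N·D1·D2) intermediate:

```python
import torch

B, N, D1, D2 = 2, 3, 4, 5
b = torch.rand(B, N, D1)
c = torch.rand(B, N, D2)

def parallel_cumsum(b, c):
    # Outer product per time step: (B, N, D1, 1) * (B, N, 1, D2) -> (B, N, D1, D2)
    outer = b.unsqueeze(-1) * c.unsqueeze(-2)
    # Prefix sum over time reproduces b[:, :t+1].mT @ c[:, :t+1] at each t
    return outer.cumsum(dim=1)

# Check against the original loop
a = torch.zeros(B, N, D1, D2)
for t in range(N):
    a[:, t] = b[:, :t + 1].mT @ c[:, :t + 1]
print(torch.allclose(parallel_cumsum(b, c), a, atol=1e-5))  # True
```

The results can differ in the last few float bits because the summation order changes, hence the tolerance in the check.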

Cheers

I had not, and this is a great reminder that torch.compile is a super useful feature.