# Reference loop: a[:, t] = b[:, :t+1]^T @ c[:, :t+1] for each step t.
# Assumes b is (B, N, D1) and c is (B, N, D2); B, N, D1, D2 defined beforehand.
a = torch.zeros(B, N, D1, D2)
for t in range(N):
    Bt = b[:, :t + 1, :]  # B,T,D1  (T = t+1)
    Ct = c[:, :t + 1, :]  # B,T,D2
    a[:, t, :, :] = Bt.mT @ Ct  # B,D1,T @ B,T,D2 -> B,D1,D2
It takes two 3D tensors, `b` of shape (B, N, D1) and `c` of shape (B, N, D2). For each step t it multiplies the transpose of rows 0..t of `b` by rows 0..t of `c`, and stores the result in `a`, a 4D tensor of shape (B, N, D1, D2). But the loop gets really slow (especially during back-propagation) as N grows. Is there a faster alternative, some way to do this all with matrix multiplication?
Thanks in advance for your help!!!
FWIW, I eventually reshape `a` to (B, N, D1*D2), so solutions in that shape work too!
import torch
# Small demo sizes: batch, sequence length, and the two feature widths.
B, N, D1, D2 = 2, 3, 4, 5

# Pre-allocated output buffer handed to `loop`, which writes into it.
a = torch.zeros((B, N, D1, D2))

# Random inputs (b is drawn before c, keeping the original RNG order).
b = torch.rand(B, N, D1)
c = torch.rand(B, N, D2)
def loop(a, b, c):
    """Reference (slow) implementation from the question.

    For every step t, writes ``b[:, :t+1]^T @ c[:, :t+1]`` into ``a[:, t]``.
    One batched matmul is issued per step, over an ever-longer prefix.

    Args:
        a: output tensor, filled in place; in this script it is (B, N, D1, D2).
        b: tensor of shape (B, N, D1).
        c: tensor of shape (B, N, D2).

    Returns:
        The same tensor object ``a`` that was passed in.
    """
    steps = a.size(1)
    for t in range(steps):
        stop = t + 1
        prefix_b = b[:, :stop, :]
        prefix_c = c[:, :stop, :]
        # (B, D1, stop) @ (B, stop, D2) -> (B, D1, D2)
        a[:, t, :, :] = prefix_b.transpose(-2, -1) @ prefix_c
    return a
def parallel(b, c):
    """Vectorised equivalent of the step-by-step loop.

    ``b[:, :t+1]^T @ c[:, :t+1]`` is the sum over s <= t of the outer
    products ``outer(b[:, s], c[:, s])``. It is therefore a running (cumulative)
    sum of per-step outer products along the sequence dimension.

    This avoids building two (B, N, N, D) masked copies of the inputs, as the
    repeat + triu approach does. Memory drops from O(N^2) to O(N), and the
    graph stays small for back-propagation.

    Args:
        b: tensor of shape (B, N, D1).
        c: tensor of shape (B, N, D2).

    Returns:
        Tensor of shape (B, N, D1, D2) whose slice [:, t] equals
        ``b[:, :t+1].mT @ c[:, :t+1]``. Use ``.flatten(-2)`` to get the
        (B, N, D1*D2) layout.
    """
    # Per-step outer product: (B, N, D1, 1) * (B, N, 1, D2) -> (B, N, D1, D2).
    outer = b.unsqueeze(-1) * c.unsqueeze(-2)
    # Prefix sum over the sequence dimension (dim 1) gives every t at once.
    return outer.cumsum(dim=1)
# Compare the reference loop (which fills `a` in place) with the vectorised
# version. torch.allclose reduces the elementwise comparison to one bool,
# instead of printing a (B, N, D1, D2) boolean mask.
output_loop = loop(a, b, c)
output_parallel = parallel(b, c)
print(torch.allclose(output_loop, output_parallel))