Compute block diagonal of multiplication of two tensors efficiently

Hi,

The first step, extracting the 10x30 blocks, can be done the same way we extract patches from a matrix. The right function for this is Tensor.unfold (called as a.unfold(dimension, size, step)).

After this, we need to compute a @ b.T for each block. torch.bmm does this as a batched matrix multiplication, giving one 10x10 product per block.

Finally, a reshape flattens each 10x10 matrix into a vector of length 100.

import torch

# initialize section - for testing purposes
a = torch.ones((100, 30))
b = torch.ones((100, 30))
for i in range(10):
    a[i*10:(i+1)*10, :]=i
    b[i*10:(i+1)*10, :]=-i

# main code
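# unfold puts the window dimension last ([10, 30, 10]), so transpose it back;
# a reshape here would scramble the elements for non-constant blocks
a = a.unfold(0, 10, 10).transpose(1, 2)  # torch.Size([10, 10, 30])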
b = b.unfold(0, 10, 10)  # torch.Size([10, 30, 10])
result = torch.bmm(a, b)  # torch.Size([10, 10, 10])
result = result.reshape(10, -1)  # torch.Size([10, 100])
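
Since the test tensors above are constant within each block, they cannot reveal a layout mistake. Here is a rough sanity check with random inputs that compares the three steps against the diagonal blocks of the full a @ b.T product. The helper name block_diag_product and the block parameter are made up for this sketch, and it assumes the number of rows is a multiple of the block size. Note that unfold returns windows of shape [10, 30, 10], so the sketch uses transpose(1, 2) to get [10, 10, 30] for the first operand.

```python
import torch

def block_diag_product(a, b, block=10):
    # Hypothetical helper: flattened diagonal blocks of a @ b.T.
    # unfold places the window dimension last: [n_blocks, features, block]
    a_blocks = a.unfold(0, block, block).transpose(1, 2)  # [n_blocks, block, features]
    b_blocks = b.unfold(0, block, block)                  # [n_blocks, features, block]
    out = torch.bmm(a_blocks, b_blocks)                   # [n_blocks, block, block]
    return out.reshape(out.shape[0], -1)                  # [n_blocks, block * block]

torch.manual_seed(0)
a = torch.randn(100, 30)
b = torch.randn(100, 30)
fast = block_diag_product(a, b)

# Reference: build the full 100x100 product and slice out the diagonal blocks.
full = a @ b.T
reference = torch.stack(
    [full[i * 10:(i + 1) * 10, i * 10:(i + 1) * 10].reshape(-1) for i in range(10)]
)
matches = torch.allclose(fast, reference, atol=1e-4)
print(fast.shape, matches)
```

The reference path materializes all 100x100 entries, while the batched version only computes the 10 diagonal blocks, which is where the saving comes from.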

Best,
Nik

edit: if you have any trouble with fold and unfold, this discussion may help.
