Compute block diagonal of multiplication of two tensors efficiently

Assume we have two 100 by 30 matrices, A and B. I want to compute the ten 10 by 10 diagonal blocks of the matrix AB^T and arrange them into a 10 by 100 matrix by concatenating them. Is there an efficient way to do this?
Currently, I chunk A and B into ten 10 by 30 matrices, compute A_{i} * B_{i}^T for each chunk 0 <= i < 10, and concatenate the results, which requires a for loop. I want to do this without any loops.
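For reference, a minimal sketch of the loop-based approach described above might look like this (concatenating the blocks along dim=1 here; the exact layout depends on how you want the blocks arranged):

```python
import torch

A = torch.randn(100, 30)
B = torch.randn(100, 30)

blocks = []
for i in range(10):
    Ai = A[i * 10:(i + 1) * 10]   # (10, 30) chunk of A
    Bi = B[i * 10:(i + 1) * 10]   # (10, 30) chunk of B
    blocks.append(Ai @ Bi.T)      # (10, 10) i-th diagonal block of A @ B.T
result = torch.cat(blocks, dim=1)  # (10, 100)
```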
Thanks in advance.


The first step, extracting the 10x30 chunks, can be done the same way we extract patches from matrices. The proper function for this is torch.Tensor.unfold.
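A quick shape check of `Tensor.unfold(dim, size, step)` on a 100x30 tensor — note that the window dimension ends up last:

```python
import torch

x = torch.arange(100 * 30, dtype=torch.float32).reshape(100, 30)
patches = x.unfold(0, 10, 10)   # windows of size 10 along dim 0, step 10
print(patches.shape)            # torch.Size([10, 30, 10])
# each patches[i] is x[i*10:(i+1)*10] with the window dim last, i.e. transposed
```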

After this, we need to compute the a @ b.T operation over batches; torch.bmm lets us do this matrix product patch-wise.
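`torch.bmm` multiplies matched pairs of matrices in a batch, mapping shapes (b, n, m) and (b, m, p) to (b, n, p) — exactly the shapes produced above:

```python
import torch

x = torch.randn(10, 10, 30)
y = torch.randn(10, 30, 10)
out = torch.bmm(x, y)   # out[i] == x[i] @ y[i]
print(out.shape)        # torch.Size([10, 10, 10])
```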

Finally, it is just a reshape to convert each 10x10 matrix into a 100-element vector.

import torch

# initialize section - testing purposes
a = torch.ones((100, 30))
b = torch.ones((100, 30))
for i in range(10):
    a[i*10:(i+1)*10, :]=i
    b[i*10:(i+1)*10, :]=-i

# main code
a = a.unfold(0, 10, 10).reshape(10, -1, 30)  # torch.Size([10, 10, 30])
b = b.unfold(0, 10, 10)  # torch.Size([10, 30, 10])
result = torch.bmm(a, b)  # torch.Size([10, 10, 10])
result = result.reshape(10, -1)  # torch.Size([10, 100])


Edit: if you have any trouble with fold and unfold, this discussion may help.


Thanks! I understand what unfold does.
But for a, I think we need to use transpose(-1, -2) instead of reshape(10, -1, 30). Am I right? It seems that the last reshape for result also doesn’t give the right answer.

Yes! You are right about transpose(-1, -2) for a. Since I used the same constant value for all cells of each patch, the results of transpose and reshape happened to be identical.

Concerning the last reshape, it just works like flatten, since we want to merge dim=1 and dim=2 into a 1D tensor of size 100. As a result, we are not transposing any dim with any other, and as you mentioned in your question, it is more like concatenating them in the way they already are. What do you think?
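Putting the transpose fix into the earlier snippet, a sketch of the corrected end-to-end version, checked against a loop-based reference:

```python
import torch

a = torch.randn(100, 30)
b = torch.randn(100, 30)

# vectorised version with the transpose(-1, -2) fix
pa = a.unfold(0, 10, 10).transpose(-1, -2)  # (10, 10, 30): pa[i] == a[i*10:(i+1)*10]
pb = b.unfold(0, 10, 10)                    # (10, 30, 10): pb[i] == b[i*10:(i+1)*10].T
blocks = torch.bmm(pa, pb)                  # (10, 10, 10): blocks[i] == a_i @ b_i.T
result = blocks.reshape(10, -1)             # (10, 100): row i is block i flattened

# loop reference: block i is the i-th 10x10 diagonal block of a @ b.T
ref = torch.stack([a[i*10:(i+1)*10] @ b[i*10:(i+1)*10].T for i in range(10)])
print(torch.allclose(result, ref.reshape(10, -1)))  # True
```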