Repeated matrix multiplication

I am currently working on implementing attention for a Variational Autoencoder, as implemented in

I am having a hard time identifying an efficient solution for the following.

I have a tensor A with shape (11000, 11000).
I also have a tensor H with shape (11000, 20).
I want to populate a tensor C, also with shape (11000, 20),
where the rows of C are:

C[j,:] = A[0,j] * H[0,:] + A[1,j] * H[1,:] + … + A[10999,j] * H[10999,:]

for all j in range(11000).
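As a sketch on small, arbitrarily chosen sizes (N = 50, D = 4 stand in for 11000 and 20), the formula above can be written as an explicit double loop and compared with a single matrix product. Since C[j,:] sums A[i,j] * H[i,:] over i, it is the transpose of A multiplied by H, i.e. A.T @ H.

```python
import torch

# Small stand-ins for the (11000, 11000) A and (11000, 20) H in the question.
N, D = 50, 4
A = torch.rand(N, N, dtype=torch.float64)
H = torch.rand(N, D, dtype=torch.float64)

# Literal translation of C[j,:] = sum over i of A[i,j] * H[i,:] (slow nested loops).
C_loop = torch.zeros(N, D, dtype=torch.float64)
for j in range(N):
    for i in range(N):
        C_loop[j, :] += A[i, j] * H[i, :]

# The same sum as a single matrix product: C = A^T H, shape (N, D).
C_mm = A.T @ H

print(C_mm.shape)                    # torch.Size([50, 4])
print(torch.allclose(C_loop, C_mm))  # True
```

At full size, A.T @ H replaces the Python loops with one library call and does not build any three-dimensional intermediate tensor.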

I tried nested for-loops, but they were far too slow.

Hope you can provide some help, thanks in advance!

Best Regards, Jonas


You want to multiply all the values in A by all the values in H in a batched manner, and then reduce over the dimension of A:
C = (A.unsqueeze(1) * H.unsqueeze(2)).sum(dim=1)
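A quick shape check on small, arbitrarily chosen sizes (N = 6, D = 3) shows how this broadcast works: A.unsqueeze(1) has shape (N, 1, N) and H.unsqueeze(2) has shape (N, D, 1), so their product has shape (N, D, N), and which dimension is summed decides the shape of the result.

```python
import torch

N, D = 6, 3
A = torch.rand(N, N)
H = torch.rand(N, D)

# (N, 1, N) * (N, D, 1) broadcasts to (N, D, N), with prod[i, d, k] = A[i, k] * H[i, d].
prod = A.unsqueeze(1) * H.unsqueeze(2)
print(prod.shape)             # torch.Size([6, 3, 6])

# Reducing different dimensions gives different result shapes.
print(prod.sum(dim=1).shape)  # torch.Size([6, 6])
print(prod.sum(dim=2).shape)  # torch.Size([6, 3])
```

Note that the full (N, D, N) product is materialized before the sum, so its size grows with N * D * N.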

Thanks for your reply!

I tried it, but it doesn’t give the desired output.

import torch

A = torch.rand(11000, 11000)
H = torch.rand(11000, 20)

C = (A.unsqueeze(1) * H.unsqueeze(2)).sum(dim=1)
print(C.shape)


torch.Size([11000, 11000])

But the desired shape of C was (11000, 20).

Should it be dim=0 for the example I provided?

Oh, sorry, the sum should be over dim=2 🙂

Thank you so much! Really helped out.

Hi again.

I tried increasing the value 20 (the lstm_size in my VAE model) to a larger number, and the calculation now uses a very large amount of memory.

Below is code that reproduces the high memory consumption:

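import torch

A = torch.rand(11000, 11000)
H = torch.rand(11000, 200)

# Broadcasting builds an (11000, 200, 11000) intermediate before the sum over dim=2.
C = (A.unsqueeze(1) * H.unsqueeze(2)).sum(dim=2)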

The code above mirrors how the calculation is written in my model, and it causes the session to crash.

When I run the model, I get the following error message, which points to the calculation of C:

CUDA out of memory. Tried to allocate 180.30 GiB (GPU 0; 15.90 GiB total capacity; 6.81 GiB already allocated; 8.10 GiB free; 7.10 GiB reserved in total by PyTorch)
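The size of the failed allocation can be reproduced from the shape of the broadcast intermediate, (11000, 200, 11000). The back-of-the-envelope check below computes it for 4-byte (float32) and 8-byte (float64) elements; the 180.30 GiB in the message matches 8-byte elements, which suggests, but does not prove, that the tensors in the model are float64 (torch.rand defaults to float32).

```python
N, D = 11000, 200

# Number of entries in the (N, D, N) intermediate created before the sum.
elements = N * D * N

for name, nbytes in [("float32", 4), ("float64", 8)]:
    gib = elements * nbytes / 2**30
    print(f"{name}: {gib:.2f} GiB")
# float32: 90.15 GiB
# float64: 180.30 GiB
```

Either way, the intermediate is far larger than the 15.90 GiB GPU, and it scales linearly with the lstm_size D.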

Is there a way to optimize the calculation of C?

Thanks in advance.

Regards Jonas

Yes, the intermediate tensor will be very large in this case: it has shape (11000, 200, 11000) before the sum.

Looking back at it, you can get the same result with A.sum(-1, keepdim=True) * H, which uses much less memory.
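A small check with arbitrary integer-valued test tensors (chosen so that float64 sums are exact) confirms that the two expressions agree: summing over dim=2 just scales row i of H by the sum of row i of A. It also shows that both differ from A.T @ H, which is what the formula C[j,:] = A[0,j] * H[0,:] + … in the question literally describes.

```python
import torch

# Small integer-valued test tensors (arbitrary), so float64 sums are exact.
A = torch.arange(16, dtype=torch.float64).reshape(4, 4)
H = torch.arange(8, dtype=torch.float64).reshape(4, 2)

# Broadcast version: builds a (4, 2, 4) intermediate, then sums over dim=2.
C_broadcast = (A.unsqueeze(1) * H.unsqueeze(2)).sum(dim=2)

# Memory-light version: scale row i of H by the sum of row i of A.
C_rowsum = A.sum(-1, keepdim=True) * H  # (4, 1) * (4, 2) -> (4, 2)
print(torch.equal(C_broadcast, C_rowsum))  # True

# Literal reading of C[j,:] = sum over i of A[i,j] * H[i,:] is A^T H instead.
C_formula = A.T @ H
print(C_formula[0, 0].item(), C_rowsum[0, 0].item())  # 112.0 0.0
```

If the formula in the question is the intended computation, A.T @ H computes it directly and, like the row-sum version, never builds an (N, D, N) intermediate.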
