Matrix multiplication shapes

Hi there,

I would like to do a matrix multiplication, but I am not sure how to implement it. It amounts to applying a different linear forward pass to every 2D element in the batch.

For every 2D element of shape [seq_len, hidden_in] in the batch, I would like to multiply it with its own matrix of shape [hidden_out, hidden_in], producing a batched output whose elements have shape [seq_len, hidden_out]. The term "batch" is a bit misleading here: usually you perform the same operation on all elements in the batch, whereas I want a different operation for each element in the batch, but the same operation along the sequence dimension.

The (batched) shapes are as follows:

Input shape: [batch, seq_len, hidden_in]
Matrix: [batch, hidden_out, hidden_in]
Desired output shape: [batch, seq_len, hidden_out]

It would replace an implementation with linear layers in a loop, something along the lines of:

import torch

# weights_for_elements_in_batch is of shape [batch, hidden_out, hidden_in]
# input is of shape [batch, seq_len, hidden_in]
batch_size = input.shape[0]
outs = []
for i in range(batch_size):
    out = torch.nn.functional.linear(input[i, :, :], weight=weights_for_elements_in_batch[i, :, :])
    outs.append(out)  # [seq_len, hidden_out] result for element i
output = torch.stack(outs, dim=0)  # stack all elements in batch
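For comparison, a minimal sketch of the same computation without the Python loop, assuming the weights keep the F.linear layout [batch, hidden_out, hidden_in]. Since F.linear computes input @ weight.T, the loop can be replaced by a single torch.bmm call on the transposed weights. The sizes below are arbitrary, chosen only for illustration.

```python
import torch
import torch.nn.functional as F

batch, seq_len, hidden_in, hidden_out = 4, 10, 6, 3

x = torch.randn(batch, seq_len, hidden_in)
weights = torch.randn(batch, hidden_out, hidden_in)  # one [hidden_out, hidden_in] matrix per batch element

# Loop version: a separate F.linear call per batch element
loop_out = torch.stack([F.linear(x[i], weights[i]) for i in range(batch)], dim=0)

# Batched version: bmm multiplies x[i] ([seq_len, hidden_in]) by weights[i].T ([hidden_in, hidden_out])
bmm_out = torch.bmm(x, weights.transpose(1, 2))

print(bmm_out.shape)  # torch.Size([4, 10, 3])
print(torch.allclose(loop_out, bmm_out, atol=1e-5))
```

The two results should agree up to floating-point rounding, since both compute the same sums in a possibly different order.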

If anyone has a suggestion, please let me know. Help is appreciated!


Claartje Barkhof


PyTorch provides torch.einsum, which lets you express many tensor operations by labelling the dimensions of each operand with indices. See here, here, here for more explanations. I find einsum the easiest way to do such operations.

(Here I assume that matrix has shape [batch, hidden_in, hidden_out], so that each 2D slice of shape [seq_len, hidden_in] is multiplied by a 2D slice of shape [hidden_in, hidden_out]. This is the transpose of the F.linear weight layout [hidden_out, hidden_in], because F.linear computes input @ weight.T.)
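If you would rather keep the weights in the F.linear layout [batch, hidden_out, hidden_in], a sketch of the einsum version is below: the equation 'bsi, boi -> bso' contracts over the last index of both operands, so no explicit transpose is needed. The per-element bias of shape [batch, hidden_out] is a hypothetical addition (the question has no bias); it is broadcast over the sequence dimension, mirroring the bias argument of F.linear.

```python
import torch
import torch.nn.functional as F

batch, seq_len, hidden_in, hidden_out = 8, 25, 16, 32

x = torch.randn(batch, seq_len, hidden_in)
weights = torch.randn(batch, hidden_out, hidden_in)  # F.linear layout, one matrix per batch element
bias = torch.randn(batch, hidden_out)                # hypothetical: one bias vector per batch element

# Contract over i; b stays a batch index, s and o are the free output indices
out = torch.einsum('bsi, boi -> bso', x, weights) + bias[:, None, :]

# Reference: the per-element F.linear loop from the question, with the bias added
ref = torch.stack([F.linear(x[i], weights[i], bias[i]) for i in range(batch)], dim=0)

print(out.shape)  # torch.Size([8, 25, 32])
print(torch.allclose(out, ref, atol=1e-5))
```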

import torch

batch = 8
seq_len = 25
hidden_in = 16
hidden_out = 32

input = torch.randn((batch, seq_len, hidden_in))
matrix = torch.randn((batch, hidden_in, hidden_out))

output = torch.einsum('bsi, bio -> bso', input, matrix)
print(output.shape)
# torch.Size([8, 25, 32])

The most important element is the formula 'bsi, bio -> bso', which basically says:

I have two inputs with index labels (b, s, i) and (b, i, o) (hopefully the meaning of each letter is clear), and an output with labels (b, s, o). For each index along b, I multiply the corresponding 2D slices, i.e. a slice of shape (s, i) with a slice of shape (i, o). The part 'si, io -> so' is just the regular 2D matrix multiplication rule: multiply and sum along i, the dimension shared by both operands, which gives a result of shape (s, o) for each pair in the batch dimension.
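To make the "multiply and sum along i" step concrete, here is a small illustrative sketch that computes the same result without einsum: input is broadcast to [b, s, i, 1] and matrix to [b, 1, i, o], the two are multiplied elementwise, and the i dimension is summed out. This builds a [b, s, i, o] intermediate, so it is only meant to show what the equation means; torch.matmul (the @ operator) computes the same batched product efficiently.

```python
import torch

batch, seq_len, hidden_in, hidden_out = 2, 3, 4, 5

x = torch.randn(batch, seq_len, hidden_in)
matrix = torch.randn(batch, hidden_in, hidden_out)

einsum_out = torch.einsum('bsi, bio -> bso', x, matrix)

# Explicit version: out[b, s, o] = sum over i of x[b, s, i] * matrix[b, i, o]
# x[..., None] has shape [b, s, i, 1]; matrix[:, None] has shape [b, 1, i, o]
explicit_out = (x[..., None] * matrix[:, None]).sum(dim=2)

# Batched matmul: multiplies the matching [s, i] and [i, o] slices for each b
matmul_out = x @ matrix

print(explicit_out.shape)  # torch.Size([2, 3, 5])
print(torch.allclose(einsum_out, explicit_out, atol=1e-5))
print(torch.allclose(einsum_out, matmul_out, atol=1e-5))
```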


Thanks a lot @rad. That’s very informative & helpful.
