Hi there,

I would like to do a matrix multiplication that I am not sure how to implement. It amounts to applying a different linear forward pass to every 2D element in the batch.

Every 2D element of shape `[seq_len, hidden_in]` in the batch should be multiplied with its own matrix of shape `[hidden_out, hidden_in]`, giving a batch of outputs of shape `[seq_len, hidden_out]`. The term "batch" is a bit misleading here: usually you perform the same operation on all elements in the batch, whereas I want a different operation for each element in the batch, but the same operation along the sequence dimension.

The (batched) shapes are as follows:

Input shape: `[batch, seq_len, hidden_in]`

Matrix shape: `[batch, hidden_out, hidden_in]`

Desired output shape: `[batch, seq_len, hidden_out]`

It would replace an implementation with linear layers in a loop, something along the lines of:

```
import torch

# weights_for_elements_in_batch has shape [batch, hidden_out, hidden_in]
# input has shape [batch, seq_len, hidden_in]
outs = []
for i in range(batch_size):
    # apply the i-th weight matrix to the i-th element: [seq_len, hidden_in] -> [seq_len, hidden_out]
    out = torch.nn.functional.linear(input[i, :, :], weight=weights_for_elements_in_batch[i, :, :])
    outs.append(out)
output = torch.stack(outs, dim=0)  # stack all elements in batch: [batch, seq_len, hidden_out]
```
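One direction that might work, as a minimal sketch rather than a confirmed solution: since `torch.nn.functional.linear(x, W)` without a bias computes `x @ W.T`, the loop should be equivalent to a single batched matrix multiplication with the weights transposed on their last two dimensions. The sizes below are made up for illustration, and `x` and `weights` are stand-in names for the two tensors described above.

```python
import torch

# Illustrative sizes only (assumptions, not from a real model)
batch, seq_len, hidden_in, hidden_out = 4, 5, 8, 3

x = torch.randn(batch, seq_len, hidden_in)           # [batch, seq_len, hidden_in]
weights = torch.randn(batch, hidden_out, hidden_in)  # [batch, hidden_out, hidden_in]

# Reference: one linear call per batch element, as in the loop above
looped = torch.stack(
    [torch.nn.functional.linear(x[i], weights[i]) for i in range(batch)],
    dim=0,
)

# Batched matmul: [batch, seq_len, hidden_in] @ [batch, hidden_in, hidden_out]
out_bmm = torch.bmm(x, weights.transpose(1, 2))

# Same contraction written with einsum: sum over the hidden_in index i
out_einsum = torch.einsum("bsi,boi->bso", x, weights)

print(out_bmm.shape)
print(torch.allclose(out_bmm, looped, atol=1e-5))
print(torch.allclose(out_einsum, looped, atol=1e-5))
```

The `einsum` form spells out which dimensions are shared (`b`), kept (`s`, `o`), and summed over (`i`), which may be easier to read than the explicit transpose.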

If anyone has a suggestion, please let me know. Help is appreciated!

Cheers,

Claartje Barkhof