Hi there,
I would like to do a matrix multiplication that I am not sure how to implement. It amounts to applying a different linear layer (forward pass) to every 2D element in the batch.
I would like to multiply every 2D element of shape [seq_len, hidden_in] in the batch with its own matrix of shape [hidden_out, hidden_in], producing a batch of outputs of shape [seq_len, hidden_out]. So the term "batch" is a bit misleading here: usually you perform the same operation on all elements of a batch, whereas I want a different operation for each element in the batch, but the same operation along the sequence dimension.
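Written element-wise, the operation described above would be the following, where X_b is the [seq_len, hidden_in] slice of the input and W_b the [hidden_out, hidden_in] matrix for batch element b. This mirrors the x W^T convention of a linear layer; the symbols are chosen here for illustration only:

```latex
Y_b = X_b W_b^{\top}, \qquad
Y[b, s, o] = \sum_{i=1}^{\text{hidden\_in}} X[b, s, i]\, W[b, o, i],
\qquad b = 1,\dots,\text{batch},\; s = 1,\dots,\text{seq\_len},\; o = 1,\dots,\text{hidden\_out}
```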
The (batched) shapes are as follows:
Input shape: [batch, seq_len, hidden_in]
Matrix: [batch, hidden_out, hidden_in]
Desired output shape: [batch, seq_len, hidden_out]
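A possible loop-free approach, sketched under the assumption that the shapes are exactly the ones listed above: `torch.nn.functional.linear(x, W)` computes `x @ W.T`, so transposing the last two dimensions of the weight tensor to [batch, hidden_in, hidden_out] turns the whole thing into a single batched matrix multiply with `torch.bmm`. The helper name `per_sample_linear` and the example sizes are hypothetical.

```python
import torch

def per_sample_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: a different weight matrix per batch element.

    x:      [batch, seq_len, hidden_in]
    weight: [batch, hidden_out, hidden_in]  (nn.Linear.weight layout, one per sample)
    returns [batch, seq_len, hidden_out]
    """
    # bmm multiplies x[b] ([seq_len, hidden_in]) with weight[b].T ([hidden_in, hidden_out])
    # for every b at once, which is what F.linear does per element.
    return torch.bmm(x, weight.transpose(1, 2))

# Example sizes, chosen only for illustration.
batch, seq_len, hidden_in, hidden_out = 4, 7, 5, 3
x = torch.randn(batch, seq_len, hidden_in)
w = torch.randn(batch, hidden_out, hidden_in)
out = per_sample_linear(x, w)
print(out.shape)  # torch.Size([4, 7, 3])
```

`torch.matmul(x, w.transpose(-1, -2))` should give the same result; unlike `bmm` it also broadcasts over extra leading dimensions, while `bmm` strictly requires 3D inputs with matching batch sizes.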
It would replace an implementation that applies a linear layer in a loop, something along the lines of:
```python
import torch

# weights_for_elements_in_batch is of shape [batch, hidden_out, hidden_in]
# input is of shape [batch, seq_len, hidden_in]
batch_size = input.shape[0]
outs = []
for i in range(batch_size):
    out = torch.nn.functional.linear(input[i, :, :], weight=weights_for_elements_in_batch[i, :, :])
    outs.append(out)  # each out is [seq_len, hidden_out]
output = torch.stack(outs, dim=0)  # stack all elements in batch -> [batch, seq_len, hidden_out]
```
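The loop above can also be expressed with `torch.einsum`, which makes the per-element contraction explicit. This is a sketch that checks the einsum version against the loop on random data; the sizes, the seed, and the variable names `x` and `weights` are assumptions for illustration (`x` is used instead of `input` to avoid shadowing the Python builtin).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, hidden_in, hidden_out = 3, 6, 8, 2  # illustrative sizes
x = torch.randn(batch_size, seq_len, hidden_in)
weights = torch.randn(batch_size, hidden_out, hidden_in)

# Reference: the loop from the question, one F.linear call per batch element.
loop_out = torch.stack(
    [F.linear(x[i], weight=weights[i]) for i in range(batch_size)], dim=0
)

# Vectorised: b = batch, s = seq_len, i = hidden_in, o = hidden_out.
# The sum runs over i; b appears in both operands, so each batch element
# is paired with its own weight matrix instead of sharing one.
einsum_out = torch.einsum("bsi,boi->bso", x, weights)

print(torch.allclose(loop_out, einsum_out, atol=1e-5))  # True
```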
If anyone has a suggestion, please let me know. Help is appreciated!
Cheers,
Claartje Barkhof