Hey all,
I'm trying to perform a matrix multiplication between a batch of word vectors and head-specific attention weights. I have a batch of, e.g., 30 documents, each with 200 words and 512 embedding dimensions. Additionally, I have a matrix of attention weights for several heads (4 in this example), which I want to multiply with every word in my documents. The twist is that I want each attention head (1-4) to only apply to specific dimensions of my word vectors, e.g., attention head 1 to dimensions 0-127, head 2 to dimensions 128-191, etc.
x = torch.randn(30, 200, 512) # Batch x Words x Embedding
attention = torch.randn(30, 4, 200, 200) # Batch x Heads x Word Weights x Word Weights
The code I came up with works and is reasonably fast in the forward pass, but incredibly slow in the backward pass.
start_dim = 0
num_heads = 4
head_dimensions = torch.tensor([128, 64, 32, 32])
for i in range(num_heads):
    end_dim = (start_dim + head_dimensions[i]).item()
    # In-place assignment: each head transforms its own slice of the embedding
    x[:, :, start_dim:end_dim] = torch.matmul(attention[:, i, :, :], x[:, :, start_dim:end_dim])
    start_dim = end_dim
Does anyone have an idea why this is slow during backpropagation, and how it could be solved more elegantly?