Variable length matrix multiply leads to slow backward pass

Hey all,
I'm trying to perform a matrix multiplication of a batch of word vectors with per-head attention weights. I have a batch of, e.g., 30 documents, each with 200 words and 256 embedding dimensions. Additionally, I have attention weight matrices for up to 16 different heads (4 in this example), which I want to apply to every word in each document. The twist is that each attention head (1-4) should only apply to a specific slice of the embedding dimensions, e.g., head 1 to dimensions 0-127, head 2 to dimensions 128-191, etc…

import torch

x = torch.randn(30, 200, 256)                   # Batch x Words x Embedding
attention = torch.randn(30, 4, 200, 200)        # Batch x Heads x Words x Words

The code I came up with works and is reasonably fast in the forward pass, but incredibly slow in the backward pass.

start_dim = 0
num_heads = 4
head_dimensions = torch.tensor([128, 64, 32, 32])   # per-head slice widths, summing to the embedding size

for i in range(num_heads):
    end_dim = (start_dim + head_dimensions[i]).item()   # .item() forces a host sync each iteration

    # Apply head i's attention to its slice of the embedding, writing the result back in place
    x[:, :, start_dim:end_dim] = torch.matmul(attention[:, i, :, :], x[:, :, start_dim:end_dim])

    start_dim = end_dim

Does anyone have an idea why this is so slow in the backward pass, and how it could be solved more elegantly?
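
For what it's worth, here is an out-of-place variant I sketched (just an idea, assuming that replacing the in-place slice assignment with torch.split and torch.cat gives autograd a cheaper graph to differentiate through), reusing x and attention from above:

chunks = torch.split(x, [128, 64, 32, 32], dim=2)          # one slice per head
out = torch.cat(
    [torch.matmul(attention[:, i, :, :], chunk)            # (30, 200, d_i) per head
     for i, chunk in enumerate(chunks)],
    dim=2,
)                                                          # back to Batch x Words x Embedding

As far as I can tell, this produces the same result as the loop, but it builds a new tensor instead of mutating x and avoids the per-iteration .item() call. Is something like this the right direction?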