Let's say I have 3 tensors with shapes [B, S, d_0], [B, S, d_1] and [B, S, d_2].
I concatenate them along the last dimension, so the resulting size is [B, S, d_0+d_1+d_2].
Now I need to perform torch.matmul of its transpose (last two dimensions swapped, i.e. [B, d_0+d_1+d_2, S]) with the tensor itself, resulting in a tensor of size [B, d_0+d_1+d_2, d_0+d_1+d_2].
Then I need to get back only the diagonal blocks [B, d_0, d_0], [B, d_1, d_1] and [B, d_2, d_2].
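For reference, here is a minimal sketch of the straightforward pipeline described above (concatenate, multiply the transpose by the tensor, slice out the diagonal blocks). The sizes B, S and the list dims are made-up example values, not part of the original question.

```python
import torch

torch.manual_seed(0)

# Hypothetical example sizes: batch B, sequence length S, feature dims d_0, d_1, d_2
B, S = 2, 5
dims = [3, 4, 6]

# Three input tensors of shape [B, S, d_i]
xs = [torch.randn(B, S, d) for d in dims]

# Concatenate along the last dimension -> [B, S, d_0 + d_1 + d_2]
x = torch.cat(xs, dim=-1)

# Swap the last two dims and multiply: [B, D, S] @ [B, S, D] -> [B, D, D]
gram = torch.matmul(x.transpose(1, 2), x)

# Slice out the diagonal blocks of shape [B, d_i, d_i]
blocks = []
start = 0
for d in dims:
    blocks.append(gram[:, start:start + d, start:start + d])
    start += d
```

Note that this computes the full [B, D, D] product, including all off-diagonal blocks that are then thrown away.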
Is there an efficient way to do these operations?
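One observation that may help: the i-th diagonal block of x^T x depends only on the i-th input, because it equals x_i^T x_i. So the concatenation and the full product can be skipped entirely. The sketch below (same made-up sizes as before) computes each block directly and checks it against the concat-and-slice route; torch.split is used only to build the reference.

```python
import torch

torch.manual_seed(0)

# Hypothetical example sizes
B, S = 2, 5
dims = [3, 4, 6]
xs = [torch.randn(B, S, d) for d in dims]

# Direct route: block i is x_i^T x_i, shape [B, d_i, d_i]
blocks_direct = [t.transpose(1, 2) @ t for t in xs]

# Reference route: concat, full [B, D, D] product, then slice the diagonal blocks
x = torch.cat(xs, dim=-1)
gram = x.transpose(1, 2) @ x
row_blocks = gram.split(dims, dim=1)  # each is [B, d_i, D]
blocks_ref = [r.split(dims, dim=2)[i] for i, r in enumerate(row_blocks)]

# Both routes agree up to floating-point rounding
same = all(torch.allclose(a, b, atol=1e-5) for a, b in zip(blocks_direct, blocks_ref))
print(same)
```

Per batch element, the full product costs on the order of S * (d_0+d_1+d_2)^2 multiply-adds, while the direct route costs S * (d_0^2 + d_1^2 + d_2^2). If all d_i happen to be equal, the inputs could also be stacked along a new dimension so that a single batched matmul computes all blocks at once.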