Concatenate unequal size tensors, do operation and then split them back

Let's say I have 3 tensors with shapes [B, S, d_0], [B, S, d_1], [B, S, d_2].
I concatenate them along the last dimension, so the result has shape [B, S, d_0+d_1+d_2].
Now I need to perform torch.matmul with the transposed version of itself, resulting in a tensor of shape [B, d_0+d_1+d_2, d_0+d_1+d_2].
Then I need to get back only the diagonal blocks of shapes [B, d_0, d_0], [B, d_1, d_1], [B, d_2, d_2].

Is there an efficient way to do these operations?

You could write these operations directly as you’ve described them and let torch.compile try to optimize them. That said, your code would only involve a few metadata manipulations and launch two kernels (torch.cat and the matmul), so it’s unclear how much torch.compile could improve on it.
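A minimal sketch of the direct approach, assuming hypothetical sizes for B, S, and the d_i (the slicing offsets are just running sums of the d_i):

```python
import torch

B, S = 2, 5
d_sizes = [3, 4, 6]  # hypothetical d_0, d_1, d_2
xs = [torch.randn(B, S, d) for d in d_sizes]

# Concatenate along the feature dimension: [B, S, d_0+d_1+d_2]
x = torch.cat(xs, dim=-1)

# Multiply by the transposed version of itself: [B, D, D], D = sum(d_sizes)
gram = torch.matmul(x.transpose(-2, -1), x)

# Running offsets into the concatenated dimension: [0, d_0, d_0+d_1, ...]
offsets = [0]
for d in d_sizes:
    offsets.append(offsets[-1] + d)

# Slice out only the diagonal blocks [B, d_i, d_i]
blocks = [gram[:, o:o + d, o:o + d] for o, d in zip(offsets, d_sizes)]
```

Note that each diagonal block here equals `xs[i].transpose(-2, -1) @ xs[i]`, so the slicing is cheap (views into `gram`) and the whole pipeline is dominated by the single batched matmul.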