How to calculate multiple linear layers in one pass

Thanks a lot for this elegant solution. I have one more question about batched matmul: if, instead of (n_head, hidden), each y_i has shape (batch, windows, n_head, hidden), do I need to manually unsqueeze the weight tensor to add the first two dimensions and then repeat it to match the size of y_i?
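To make the question concrete, here is a minimal sketch of the setup, assuming PyTorch and a weight tensor of shape (n_head, hidden, out_dim). The sizes, the name `out_dim`, and the weight layout are placeholders, not taken from the solution above. In this sketch, both `torch.einsum` and `torch.matmul` seem to broadcast the weight over the leading (batch, windows) dimensions, so no manual repeat appears to be needed.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, windows, n_head, hidden, out_dim = 2, 3, 4, 8, 5

torch.manual_seed(0)
y = torch.randn(batch, windows, n_head, hidden)
# Assumed layout: one (hidden, out_dim) weight matrix per head.
weight = torch.randn(n_head, hidden, out_dim)

# Reference: apply each head's linear map separately, then stack the heads back.
reference = torch.stack(
    [y[..., i, :] @ weight[i] for i in range(n_head)], dim=-2
)

# Option A: einsum names the shared head axis "n"; the weight is not repeated.
out_einsum = torch.einsum("bwnh,nho->bwno", y, weight)

# Option B: matmul broadcasts the weight's batch dim (n_head,) against
# (batch, windows, n_head) once y is viewed as a stack of 1 x hidden row vectors.
out_matmul = torch.matmul(y.unsqueeze(-2), weight).squeeze(-2)

print(out_einsum.shape, out_matmul.shape)
```

Neither option materializes a (batch, windows, n_head, hidden, out_dim) copy of the weight, which is the main reason to prefer broadcasting over an explicit `repeat` if it does what I need.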