Imagine I have a tensor x of shape (Batch, …, Heads, Sequence, Dim).
Let’s say I want to concatenate the heads along the last dim, i.e. get a tensor of shape (Batch, …, Sequence, Heads * Dim). Doing a
x.view(*x.shape[:-3], x.shape[-2], -1)
would return a tensor of the right shape but with the wrong contents, since view does not reorder the elements.
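To make the mismatch concrete, here is a small sketch. The sizes (Batch=2, Heads=3, Sequence=4, Dim=5) are arbitrary and chosen only for illustration, and the reference result is built by slicing each head explicitly:

```python
import torch

# Arbitrary illustrative sizes: Batch=2, Heads=3, Sequence=4, Dim=5.
x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# view only reinterprets the existing memory layout; it does not move
# the Heads axis next to Dim.
viewed = x.view(*x.shape[:-3], x.shape[-2], -1)

# Reference result: for each sequence position, the per-head vectors
# placed side by side along the last dim.
wanted = torch.cat([x[..., h, :, :] for h in range(x.shape[-3])], dim=-1)

print(viewed.shape, wanted.shape)   # both are (Batch, Sequence, Heads * Dim)
print(torch.equal(viewed, wanted))  # the contents differ
print(viewed[0, 0])                 # first 15 consecutive values of head 0
print(wanted[0, 0])                 # position 0 of head 0, head 1, head 2
```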
I found a trick, but it feels like bad code:
x.transpose(-2, -3).reshape(*x.shape[:-3], x.shape[-2], -1)
Do you have a suggestion?
EDIT: the best I have found so far:
x = torch.cat(x.unbind(dim=-3), dim=-1)
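As a sanity check, here is a sketch comparing this with a transpose-based version. The sizes are arbitrary, and `flatten` is an alternative I am adding for comparison, not something from the approaches above; it merges the last two dims after the heads have been moved next to Dim:

```python
import torch

# Arbitrary illustrative sizes: (Batch, extra, Heads, Sequence, Dim).
x = torch.randn(2, 7, 3, 4, 5)

# unbind splits along Heads, cat joins the pieces along the last dim.
a = torch.cat(x.unbind(dim=-3), dim=-1)

# Swap Heads and Sequence, then merge (Heads, Dim) into one dim.
b = x.transpose(-2, -3).flatten(start_dim=-2)

# Same idea with an explicit target shape.
c = x.transpose(-2, -3).reshape(*x.shape[:-3], x.shape[-2], -1)

print(a.shape)                                # (2, 7, 4, 15)
print(torch.equal(a, b), torch.equal(a, c))   # all three agree
```

The transpose makes the tensor non-contiguous, so `flatten` and `reshape` copy the data here, just as `cat` does.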