How to properly concat a Tensor according to different dimensions?

Imagine I have a tensor x of shape (Batch, …, Heads, Sequence, Dim),
and I want to concatenate the heads along the last dimension, i.e. get a tensor of shape (Batch, …, Sequence, Heads * Dim). Simply doing

x.view(*x.shape[:-3], x.shape[-2], -1)

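does not give what I want: view only reinterprets the underlying memory, which is laid out in (…, Heads, Sequence, Dim) order, so the Heads and Sequence entries get scrambled instead of the heads being placed side by side.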
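A small check with arbitrary toy sizes (Batch=1, Heads=2, Sequence=3, Dim=2, chosen only for illustration) makes the mismatch visible. The shapes agree, but the values do not match an explicit per-head concatenation.

```python
import torch

# Toy sizes for illustration: Batch=1, Heads=2, Sequence=3, Dim=2
x = torch.arange(12).reshape(1, 2, 3, 2)

# view keeps the (Batch, Heads, Sequence, Dim) memory order
wrong = x.view(*x.shape[:-3], x.shape[-2], -1)

# Desired layout: row s = head 0's vector at position s, then head 1's vector at s
wanted = torch.cat([x[:, h] for h in range(x.shape[1])], dim=-1)

print(wrong[0].tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(wanted[0].tolist())  # [[0, 1, 6, 7], [2, 3, 8, 9], [4, 5, 10, 11]]
```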
I found a trick, but it feels like bad code:

x.transpose(-2, -3).reshape(*x.shape[:-3], x.shape[-2], -1)  # (Batch, …, Heads, Sequence, Dim) -> (Batch, …, Sequence, Heads * Dim)
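One way to make the transpose trick read less like a hack is to wrap it in a small helper. This is a sketch; `merge_heads` is a hypothetical name, and the toy sizes are arbitrary. It uses `flatten(start_dim=-2)` to merge Heads and Dim, so it works for any number of leading dimensions.

```python
import torch

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """(Batch, …, Heads, Sequence, Dim) -> (Batch, …, Sequence, Heads * Dim)."""
    # Move Heads next to Dim: (Batch, …, Sequence, Heads, Dim).
    # transpose returns a non-contiguous view, so .view() cannot generally merge the last two dims.
    y = x.transpose(-3, -2)
    # flatten merges Heads and Dim; like reshape, it copies only when it has to.
    return y.flatten(start_dim=-2)

x = torch.randn(2, 5, 4, 7, 8)  # Batch=2, one extra dim=5, Heads=4, Sequence=7, Dim=8
out = merge_heads(x)
print(out.shape)  # torch.Size([2, 5, 7, 32])
```

The result matches the `torch.cat(x.unbind(dim=-3), dim=-1)` approach element for element; the difference is mostly readability.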

Do you have a better suggestion?

EDIT: the best I have found so far:

x = torch.cat(x.unbind(dim=-3), dim=-1)
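To see why this works: `unbind(dim=-3)` returns a tuple with one tensor per head, each of shape (Batch, …, Sequence, Dim), and `cat(..., dim=-1)` places them side by side. The sketch below (toy sizes chosen arbitrarily) checks that each block of Dim columns in the result is exactly one head.

```python
import torch

x = torch.randn(2, 4, 7, 8)        # Batch=2, Heads=4, Sequence=7, Dim=8
heads = x.unbind(dim=-3)           # tuple of 4 tensors, each (2, 7, 8)
out = torch.cat(heads, dim=-1)     # (2, 7, 32)

D = x.shape[-1]
for h in range(x.shape[-3]):
    # Columns h*D .. (h+1)*D - 1 of the result hold head h unchanged
    assert torch.equal(out[..., h * D:(h + 1) * D], x[..., h, :, :])

print(len(heads), out.shape)  # 4 torch.Size([2, 7, 32])
```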

I would probably have done it this way:

Assuming x’s shape is (Batch, Heads, Sequence, Dim):

x.permute(0, 2, 1, 3).flatten(start_dim=2, end_dim=3)  # not tested
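A quick way to test the line above, assuming exactly four dimensions (the hard-coded permute order only fits that case), is to compare it against the `unbind`/`cat` version from the question on random data:

```python
import torch

x = torch.randn(2, 4, 7, 8)  # (Batch, Heads, Sequence, Dim), toy sizes

# Swap Heads and Sequence, then merge Heads and Dim
a = x.permute(0, 2, 1, 3).flatten(start_dim=2, end_dim=3)
# Approach from the question's EDIT
b = torch.cat(x.unbind(dim=-3), dim=-1)

print(a.shape)            # torch.Size([2, 7, 32])
print(torch.equal(a, b))  # True
```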

Not sure which is the “best” way though.

EDIT: I may have misunderstood. Do you just want to swap dimensions? If so, that can be done with permute: x.permute(0, 2, 1, 3).
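Note that permute on its own only reorders dimensions; the heads stay a separate axis. A sketch with arbitrary toy sizes shows that a reshape (or flatten) is still needed to reach (Batch, Sequence, Heads * Dim), and that it has to copy because the permuted tensor is not contiguous:

```python
import torch

x = torch.randn(2, 4, 7, 8)        # (Batch, Heads, Sequence, Dim)
swapped = x.permute(0, 2, 1, 3)    # (Batch, Sequence, Heads, Dim): dims reordered only
merged = swapped.reshape(*swapped.shape[:2], -1)  # (Batch, Sequence, Heads * Dim)

print(swapped.shape)            # torch.Size([2, 7, 4, 8])
print(merged.shape)             # torch.Size([2, 7, 32])
print(swapped.is_contiguous())  # False
```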