Concatenate dimensions in a tensor

import torch
x = torch.randn(2, 3, 4, 5)  # [batch_size, n_chunk, input_size_1, input_size_2]
x_input = x.view(-1, 1, 4, 5)  # [batch_size * n_chunk, 1, input_size_1, input_size_2]
x_bar = model(x_input)  # model is a CNN autoencoder, for example

As the example shows, x has the shape given in the comment; the n_chunk dimension exists because I split each input into chunks. I use view to reshape it into the input for model, a CNN autoencoder, so the batch size becomes batch_size * n_chunk with 1 channel. x_bar therefore has the same shape as x_input: [batch_size * n_chunk, 1, input_size_1, input_size_2].
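A small sketch of what that view does, using the toy sizes from the question. The autoencoder is replaced by a hypothetical stand-in (a single shape-preserving Conv2d layer), since the real model isn't shown; the point is only that row b * n_chunk + c of x_input holds chunk c of sample b.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
x = torch.randn(batch_size, n_chunk, input_size_1, input_size_2)

# Merge the batch and chunk dimensions and add a singleton channel dimension.
x_input = x.view(-1, 1, input_size_1, input_size_2)
print(x_input.shape)  # torch.Size([6, 1, 4, 5])

# Row (b * n_chunk + c) of x_input is chunk c of sample b.
for b in range(batch_size):
    for c in range(n_chunk):
        assert torch.equal(x_input[b * n_chunk + c, 0], x[b, c])

# Hypothetical stand-in for the autoencoder: one conv layer that keeps H and W.
model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
x_bar = model(x_input)
print(x_bar.shape)  # torch.Size([6, 1, 4, 5])
```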

My question is: how do I reshape x_bar into the shape [batch_size, input_size_1, input_size_2 * n_chunk] without messing up the data order, as a plain view would?

To clarify, x is a batch of audio spectrograms, each split into n_chunk chunks, with input_size_1 frequency bands and input_size_2 time frames per chunk. That is why I'd like to concatenate the chunks along the time axis.
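To see why a plain view scrambles the order, here is a sketch with the toy sizes from the question. It assumes x_bar has shape [batch_size * n_chunk, 1, input_size_1, input_size_2] and fills it with torch.arange so the comparison is deterministic; the target is built by explicitly concatenating the chunks along the last (time) axis.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
# Stand-in for the autoencoder output: [batch_size * n_chunk, 1, H, W].
x_bar = torch.arange(batch_size * n_chunk * input_size_1 * input_size_2, dtype=torch.float32)
x_bar = x_bar.view(batch_size * n_chunk, 1, input_size_1, input_size_2)

# Target: for each sample, chunks placed side by side along the time axis.
chunks = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)
target = torch.cat([chunks[:, c] for c in range(n_chunk)], dim=-1)
print(target.shape)  # torch.Size([2, 4, 15])

# A plain view gets the same shape but just re-slices the flat memory:
# each output row is a run of consecutive values (frequency rows 0-2 of
# chunk 0 for the first row) instead of one frequency row from every chunk.
naive = x_bar.view(batch_size, input_size_1, input_size_2 * n_chunk)
print(naive.shape)  # torch.Size([2, 4, 15])
print(torch.equal(naive, target))  # False
```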


You can use permute to reorder the dimensions however you want (it is a generalized transpose of the tensor).

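The view function doesn't move any data: it reads the elements in their memory (row-major) order, from the outer dimensions inwards, with the last dimension varying fastest.

If you have a 5-D tensor, the dimensions are traversed in this order, innermost first:
[5,4,3,2,1]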
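A tiny sketch of both points on a 6-element tensor: view only regroups elements in row-major order, while permute changes which dimension is which, and the result of permute usually can't be passed straight to view.

```python
import torch

t = torch.arange(6)

# view only regroups elements in row-major order (last dim varies fastest).
print(t.view(2, 3).tolist())  # [[0, 1, 2], [3, 4, 5]]
print(t.view(3, 2).tolist())  # [[0, 1], [2, 3], [4, 5]]

# permute actually reorders the dimensions; for 2-D it is a transpose.
p = t.view(2, 3).permute(1, 0)
print(p.tolist())  # [[0, 3], [1, 4], [2, 5]]
print(p.is_contiguous())  # False

# view cannot merge dimensions of this non-contiguous tensor;
# reshape (or .contiguous().view) copies the data when needed.
try:
    p.view(6)
except RuntimeError as err:
    print("view failed:", type(err).__name__)  # view failed: RuntimeError
print(p.reshape(6).tolist())  # [0, 3, 1, 4, 2, 5]
```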

Applied to your case, you would do:

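x_bar = model(x_input)  # model is a CNN autoencoder, for example
# [batch_size * n_chunk, 1, H, W] -> [batch_size, n_chunk, H, W] -> [batch_size, H, W, n_chunk] -> [batch_size, H, W * n_chunk]
# permute returns a non-contiguous tensor, so use reshape (not view) after it
x_bar = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2).permute(0, 2, 3, 1).reshape(batch_size, input_size_1, -1)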

Double-check the permute order, because I didn't pay attention to the exact order.

Thanks, your solution gave me the desired shape, but the content is still not in the correct order.
I realised I should have explained it more intuitively: it is basically an hstack of the chunks for each sample in the batch. I managed to do it with:

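# x: [batch_size, n_chunk, input_size_1, input_size_2], e.g. x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)
x = [torch.cat(i.chunk(n_chunk, dim=1), dim=-1) for i in x.chunk(batch_size, dim=0)]  # each item: [1, 1, input_size_1, input_size_2 * n_chunk]
x = torch.cat(x, dim=0).squeeze(1)  # [batch_size, input_size_1, input_size_2 * n_chunk]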

Could you check this?
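One way to check the chunk-and-cat version is to compare it against a single permute + reshape. This sketch assumes the autoencoder output x_bar is first viewed as [batch_size, n_chunk, input_size_1, input_size_2], the layout the list comprehension expects (chunks along dim 1). Note that the permute order here is (0, 2, 1, 3), not the (0, 2, 3, 1) suggested earlier: placing n_chunk before input_size_2 is what keeps each chunk's time frames together when the two dimensions are merged.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
# Autoencoder output, as in the question: [batch_size * n_chunk, 1, H, W].
x_bar = torch.randn(batch_size * n_chunk, 1, input_size_1, input_size_2)

# Split the merged batch back out: [batch_size, n_chunk, H, W].
x = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)

# The chunk-and-cat solution from the post.
out_cat = [torch.cat(i.chunk(n_chunk, dim=1), dim=-1) for i in x.chunk(batch_size, dim=0)]
out_cat = torch.cat(out_cat, dim=0).squeeze(1)

# Vectorized alternative: move n_chunk next to the time axis, then merge.
# [B, n_chunk, H, W] -> [B, H, n_chunk, W] -> [B, H, n_chunk * W]
out_perm = x.permute(0, 2, 1, 3).reshape(batch_size, input_size_1, n_chunk * input_size_2)

print(out_perm.shape)  # torch.Size([2, 4, 15])
print(torch.equal(out_cat, out_perm))  # True
```

Both give out[b, h, c * input_size_2 + w] == x[b, c, h, w]; the permute version avoids the Python loop over the batch.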

Here is a rough but very visual explanation of view and permute, which may help you order your data as you expect.

In this image, we concatenate three images along the batch dimension to get the output shown. I hope this visual representation helps.

We start with three images, each of shape [1, c, h, w]; after concatenation we get [3, c, h, w].
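The same concatenation in code, as a minimal sketch with arbitrary sizes (c, h, w = 3, 32, 32 are chosen only for illustration). torch.stack gives the same result if the images don't already carry the leading size-1 dimension.

```python
import torch

c, h, w = 3, 32, 32
# Three images, each with a leading batch dimension of size 1: [1, c, h, w].
images = [torch.randn(1, c, h, w) for _ in range(3)]

# Concatenate along the existing batch dimension: [3, c, h, w].
batch = torch.cat(images, dim=0)
print(batch.shape)  # torch.Size([3, 3, 32, 32])

# torch.stack creates the new leading dimension from [c, h, w] images.
stacked = torch.stack([img.squeeze(0) for img in images], dim=0)
print(torch.equal(batch, stacked))  # True
```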