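import torch
batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
x = torch.randn(batch_size, n_chunk, input_size_1, input_size_2)  # [batch_size, n_chunk, input_size_1, input_size_2]
x_input = x.view(-1, 1, input_size_1, input_size_2)  # [batch_size * n_chunk, 1, input_size_1, input_size_2]
model = torch.nn.Identity()  # placeholder for the CNN autoencoder (any shape-preserving model)
x_bar = model(x_input)  # same shape as x_input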
As the example shows, x has the shape given in the comment; the n_chunk dimension exists because I split each input into chunks. I use view to reshape it into the input for model, a CNN autoencoder, so that the batch size becomes batch_size * n_chunk with 1 channel. x_bar therefore has the same shape as x_input.
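A quick check of what that view does, assuming the toy sizes (2, 3, 4, 5) from the snippet: view only reinterprets the same row-major buffer, so chunk c of sample b lands at flattened batch index b * n_chunk + c. As long as the model preserves that layout, viewing its output back to [batch_size, n_chunk, input_size_1, input_size_2] is safe. The names x_back, b and c are only for illustration.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
x = torch.randn(batch_size, n_chunk, input_size_1, input_size_2)

# view reinterprets the same memory: chunks of one sample stay adjacent
x_input = x.view(-1, 1, input_size_1, input_size_2)

# flattened batch index is b * n_chunk + c
for b in range(batch_size):
    for c in range(n_chunk):
        assert torch.equal(x_input[b * n_chunk + c, 0], x[b, c])

# the inverse view restores the original chunk layout exactly
x_back = x_input.view(batch_size, n_chunk, input_size_1, input_size_2)
```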
My question is: how can I reshape x_bar into the shape [batch_size, input_size_1, input_size_2 * n_chunk] without scrambling the data order, as a plain view would?
To clarify, x is a batch of audio spectrograms, each split into n_chunk chunks with input_size_1 frequency bands and input_size_2 time frames, which is why I’d like to concatenate the chunks along the time axis.
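To make the problem concrete, here is a small sketch with the toy sizes above and torch.arange standing in for the autoencoder output (both are assumptions for illustration). The intended result puts each sample's chunks side by side along time; a direct view gets the requested shape but not that layout.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
numel = batch_size * n_chunk * input_size_1 * input_size_2
# stand-in for the autoencoder output: [batch_size * n_chunk, 1, freq, time]
x_bar = torch.arange(numel, dtype=torch.float32).view(
    batch_size * n_chunk, 1, input_size_1, input_size_2
)

# intended result: for each sample, its chunks side by side along time
chunks = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)
target = torch.cat([chunks[:, c] for c in range(n_chunk)], dim=-1)

# a direct view: each output row is just the next input_size_2 * n_chunk
# values in memory (several frequency rows laid end to end),
# not one frequency band followed across all chunks
naive = x_bar.view(batch_size, input_size_1, input_size_2 * n_chunk)
print(naive.shape == target.shape)  # True
print(torch.equal(naive, target))   # False
```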
You can use permute to put the dimensions in whatever order you want (it reorders the axes like a generalized transpose; no data is moved).
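A minimal illustration of permute on a small arange tensor (sizes chosen only for the example): it swaps strides rather than copying data, so the result is usually non-contiguous, and merging the permuted dimensions then needs reshape (which copies) instead of view.

```python
import torch

t = torch.arange(24).view(2, 3, 4)  # dims: (a, b, c)
p = t.permute(0, 2, 1)              # dims reordered to (a, c, b), no copy
print(p.shape)                      # torch.Size([2, 4, 3])
assert p[1, 2, 0] == t[1, 0, 2]     # p[i, j, k] is t[i, k, j]
print(p.is_contiguous())            # False: only the strides were swapped

# merging the permuted dims is incompatible with p's strides, so view raises
try:
    p.view(2, 12)
except RuntimeError as err:
    print("view failed:", type(err).__name__)

# reshape copies when it has to (same as p.contiguous().view(2, 12))
flat = p.reshape(2, 12)
```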
view does not reorder anything: it reads the elements in row-major order, where the last dimension varies fastest. So if you have a 5-D tensor, the dimensions are traversed in this order (innermost first):
[5,4,3,2,1]
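The traversal order can be read off the strides of a contiguous tensor. A small sketch with an arbitrary 5-D shape (2, 3, 1, 4, 5), chosen only for illustration:

```python
import torch

# a contiguous 5-D tensor holding 0, 1, 2, ... in memory order
t = torch.arange(2 * 3 * 1 * 4 * 5).view(2, 3, 1, 4, 5)

# strides shrink toward the last dim: dim 5 varies fastest, then 4, 3, 2, 1
print(t.stride())  # (60, 20, 20, 5, 1)

# element [i, j, k, l, m] sits at flat position i*60 + j*20 + k*20 + l*5 + m
assert t[1, 2, 0, 3, 4].item() == 1 * 60 + 2 * 20 + 0 * 20 + 3 * 5 + 4

# any later .view walks that same flat sequence and just regroups it
v = t.view(2, 4, 15)
assert torch.equal(v.flatten(), t.flatten())
```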
Applied to your case, you have to do:
x_bar = model(x_input)  # model is a CNN autoencoder, for example
# reshape rather than view: the permuted tensor is no longer contiguous
x_bar = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2).permute(0, 2, 3, 1).reshape(batch_size, input_size_1, -1)
Double-check the permute order, because I didn't pay close attention to the exact order.
Thanks, your solution gave me the desired shape but the content is still not in the correct order.
I realised that I should have explained it more intuitively: it is basically an hstack of the chunks for each sample in the batch. I managed to do it with:
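# x: [batch_size, n_chunk, input_size_1, input_size_2], e.g. x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)
x = [torch.cat(i.chunk(n_chunk, dim=1), dim=-1) for i in x.chunk(batch_size, dim=0)]  # each: [1, 1, input_size_1, input_size_2 * n_chunk]
x = torch.cat(x, dim=0).squeeze(1)  # [batch_size, input_size_1, input_size_2 * n_chunk]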
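For comparison, a vectorized sketch of the same per-sample hstack, using random stand-in data and the toy sizes from the question (the names chunks, looped, vectorized and unbound are only for illustration). It uses permute as the first answer suggested, but with the order (0, 2, 1, 3): this moves the chunk axis next to the time axis, and reshape then merges the two so that position c * input_size_2 + t holds time frame t of chunk c. A single cat over unbind should give the same result.

```python
import torch

batch_size, n_chunk, input_size_1, input_size_2 = 2, 3, 4, 5
x_bar = torch.randn(batch_size * n_chunk, 1, input_size_1, input_size_2)
chunks = x_bar.view(batch_size, n_chunk, input_size_1, input_size_2)

# loop version from the post, applied to chunks
looped = [torch.cat(i.chunk(n_chunk, dim=1), dim=-1) for i in chunks.chunk(batch_size, dim=0)]
looped = torch.cat(looped, dim=0).squeeze(1)

# vectorized: [B, n_chunk, F, T] -> [B, F, n_chunk, T] -> [B, F, n_chunk * T]
vectorized = chunks.permute(0, 2, 1, 3).reshape(batch_size, input_size_1, n_chunk * input_size_2)

# alternative: split off the chunk axis and concatenate along time once
unbound = torch.cat(chunks.unbind(dim=1), dim=-1)

assert torch.equal(looped, vectorized)
assert torch.equal(unbound, vectorized)
```

The permute version avoids the Python loop; reshape is used instead of view because the permuted tensor is not contiguous.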