From channels to tiles

Hi,

I’m new to PyTorch and I was wondering how to write code that efficiently transforms channels into tiles, i.e. goes from shape [c, n/sqrt(c), n/sqrt(c)] to shape [1, n, n] while respecting the tile ordering, which .reshape alone does not do. I’ll explain with an example:

Let’s assume I have a tensor with input.shape == torch.Size([4, 64, 64]), and I would like an output tensor with output.shape == torch.Size([1, 128, 128]), where each channel of the input tensor is a tile occupying one quarter of the output tensor, e.g. output[0, 64:, :64] == input[2].

But I don’t want to use explicit assignment or for loops. The code should be general enough to also take, e.g., an input tensor of shape [16, 32, 32] and output a tensor of shape [1, 128, 128], where the 16 channels are arranged as tiles.

Thank you!

To answer this question, it might be helpful if you wrote a naive for-loop version for testing purposes and to clarify how you want to rearrange the underlying data.

Hi @eqy, here’s an example.

Let’s say I have an input of 4 channels all of shape [128, 128]. In this simple example, the channels all look the same. I would like to rearrange each channel like a tile in a mosaic with shape [128x2, 128x2].

output = torch.zeros(128*2, 128*2)
output[:128, :128] = input[0]
output[:128, 128:] = input[1]
output[128:, :128] = input[2]
output[128:, 128:] = input[3]

Since I am a new user I can only embed a single image. So here’s the output I would like. And each of the 4 images was a channel in input.
[Image: the desired output, a 2x2 mosaic in which each tile is one input channel]

I would like the code to be flexible enough to choose, based on the number of input channels, how to arrange the tiles in the output. Ideally, given an input with shape [C, H, W] (with H == W), I’d like the output to have shape [1, H*sqrt(C), W*sqrt(C)]. But .reshape or .view alone do not rearrange the pixels the way I’d like.

Thanks for your help!

Here’s an example that, I think, more or less achieves what you want:

import torch
input = torch.randn(4, 128, 128)
output = torch.zeros(128*2, 128*2)
output[:128, :128] = input[0]
output[:128, 128:] = input[1]
output[128:, :128] = input[2]
output[128:, 128:] = input[3]

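# Steps:
# 1. Split the 4 channels into a 2x2 grid: shape (rows, cols, H, W).
# 2. Permute to (rows, H, cols, W), so the outer channel split runs along H
#    and the inner channel split runs along W.
# 3. Reshape to (2*H, 2*W).
# The permuted tensor is non-contiguous, so the final reshape already copies
# the data into a contiguous tensor; the trailing .contiguous() is kept only
# to make that explicit and does no extra work here.
output2 = input.reshape(2, 2, 128, 128).permute(0, 2, 1, 3).reshape(128*2, 128*2).contiguous()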
print(torch.allclose(output, output2))
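
For the general case raised in the question (e.g. 16 channels of shape [32, 32]), the same reshape/permute/reshape idea can be sketched as below. This is only a minimal sketch, assuming the channel count is a perfect square and the tiles are laid out row-major as in the example above. The names channels_to_tiles and channels_to_tiles_loop are hypothetical helpers, not PyTorch functions. The loop version is there purely as a reference to test against.

```python
import math
import torch


def channels_to_tiles(x: torch.Tensor) -> torch.Tensor:
    """Arrange the C channels of a [C, H, W] tensor as a row-major grid of tiles.

    Hypothetical helper. Assumes C is a perfect square; with g = sqrt(C),
    the result has shape [1, g*H, g*W].
    """
    c, h, w = x.shape
    g = math.isqrt(c)
    if g * g != c:
        raise ValueError(f"channel count {c} is not a perfect square")
    # (C, H, W) -> (g, g, H, W) -> (g, H, g, W) -> (g*H, g*W)
    tiled = x.reshape(g, g, h, w).permute(0, 2, 1, 3).reshape(g * h, g * w)
    return tiled.unsqueeze(0)


def channels_to_tiles_loop(x: torch.Tensor) -> torch.Tensor:
    """Naive reference: copy each channel into its tile with explicit slicing."""
    c, h, w = x.shape
    g = math.isqrt(c)
    out = torch.zeros(1, g * h, g * w, dtype=x.dtype)
    for k in range(c):
        row, col = divmod(k, g)  # row-major position of channel k in the grid
        out[0, row * h:(row + 1) * h, col * w:(col + 1) * w] = x[k]
    return out


x = torch.randn(16, 32, 32)
y = channels_to_tiles(x)
print(y.shape)  # torch.Size([1, 128, 128])
print(torch.equal(y, channels_to_tiles_loop(x)))  # True
```

If the channel count is not a perfect square, the grid shape would have to be passed in explicitly (for instance 2 rows by 4 columns for 8 channels); the sketch above simply raises an error in that case.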