Thanks for the sketch.
It looks like the values for each pixel across all channels should be placed as patches (or windows) into the final image.
Here is a small code example:
import torch

B, C, H, W = 2, 16, 4, 4
# Create a dummy input where each channel is filled with its channel index
x = torch.arange(C)[None, :, None, None].repeat(B, 1, H, W)
print(x)
> tensor([[[[ 0, 0, 0, 0],
            [ 0, 0, 0, 0],
            [ 0, 0, 0, 0],
            [ 0, 0, 0, 0]],
           [[ 1, 1, 1, 1],
            [ 1, 1, 1, 1],
            [ 1, 1, 1, 1],
            [ 1, 1, 1, 1]],
...
# Move the channel dim to the last position and split it into a 4x4 window
# (this view only splits the last dim, so it works without .contiguous())
x = x.permute(0, 2, 3, 1).view(B, H, W, 4, 4)
print(x)
> tensor([[[[[ 0, 1, 2, 3],
             [ 4, 5, 6, 7],
             [ 8, 9, 10, 11],
             [12, 13, 14, 15]],
            [[ 0, 1, 2, 3],
             [ 4, 5, 6, 7],
             [ 8, 9, 10, 11],
             [12, 13, 14, 15]],
...
# Permute "window dims" with spatial dims, view as desired output
x = x.permute(0, 1, 3, 2, 4).contiguous().view(B, 1, 4*H, 4*W)
print(x)
> tensor([[[[ 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
            [ 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7],
            [ 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11],
            [12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15],
            [ 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
            [ 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7],
            [ 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11],
            [12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15],
...
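
If I'm reading the desired layout correctly, this manual reshape should match what nn.PixelShuffle does with an upscale factor of 4 for a single output channel (the hard-coded 4s are just the window size, sqrt(C) in this example), so you could compare against it as a quick sanity check:

import torch
import torch.nn as nn

B, C, H, W = 2, 16, 4, 4
x = torch.arange(C)[None, :, None, None].repeat(B, 1, H, W)

# Manual approach from above
manual = x.permute(0, 2, 3, 1).view(B, H, W, 4, 4)
manual = manual.permute(0, 1, 3, 2, 4).contiguous().view(B, 1, 4*H, 4*W)

# Built-in pixel shuffle: (B, C*r^2, H, W) -> (B, C, H*r, W*r) with r=4, C=1
shuffled = nn.PixelShuffle(4)(x)

print(torch.equal(manual, shuffled))  # should print True if the two layouts match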