Reshape tensor while keeping data in a given dimension

I would like to reshape my data, which represents a single image, so that the width and height (dims 2 and 3) shrink and the batch size grows; effectively making small crops of an image into batch entries.
I want a specific dimension, the channels (dim 1), to keep the same data.

I initially wrote a pair of functions to do something like this, but I fear they're not backprop-compatible:

import math

import torch


def img_to_batch(imgtensor, patch_size: int):
    _, ch, height, width = imgtensor.shape
    assert height % patch_size == 0 and width % patch_size == 0, (
        'img_to_batch: dims must be divisible by patch_size. {}%{}!=0'.format(
            imgtensor.shape, patch_size))
    # One batch entry per patch: patches per column * patches per row.
    bs = math.ceil(height / patch_size) * math.ceil(width / patch_size)
    btensor = torch.zeros([bs, ch, patch_size, patch_size],
                          device=imgtensor.device, dtype=imgtensor.dtype)
    # Walk the image left to right, top to bottom, one patch per batch entry.
    xstart = ystart = 0
    for i in range(bs):
        btensor[i] = imgtensor[:, :, ystart:ystart+patch_size, xstart:xstart+patch_size]
        xstart += patch_size
        if xstart + patch_size > width:
            xstart = 0
            ystart += patch_size
    return btensor

def batch_to_img(btensor, height: int, width: int, ch=3):
    # Inverse of img_to_batch: paste each patch back at its position.
    imgtensor = torch.zeros([1, ch, height, width], device=btensor.device, dtype=btensor.dtype)
    patch_size = btensor.shape[-1]
    xstart = ystart = 0
    for i in range(btensor.size(0)):
        imgtensor[0, :, ystart:ystart+patch_size, xstart:xstart+patch_size] = btensor[i]
        xstart += patch_size
        if xstart + patch_size > width:
            xstart = 0
            ystart += patch_size
    return imgtensor
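
They also pass a quick round-trip check (a minimal example, assuming the dimensions are divisible by the patch size):

>>> img = torch.rand(1, 3, 8, 8)
>>> batch = img_to_batch(img, 4)
>>> batch.shape
torch.Size([4, 3, 4, 4])
>>> torch.equal(batch_to_img(batch, 8, 8), img)
True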

There are simple view / reshape functions in the torch library, but when I use them the channels do not keep their respective data. For example (I would like the data shown below to contain the same elements, even if out of order):

>>> imgtens = torch.rand(1,3,4,4)
>>> imgtens[:,0,:,:]
tensor([[[0.6830, 0.2091, 0.8786, 0.6002],
         [0.0325, 0.7217, 0.1479, 0.3478],
         [0.0880, 0.8705, 0.0929, 0.7978],
         [0.7604, 0.2658, 0.3518, 0.1969]]])
>>> reshaped = imgtens.view([4,3,2,2])
>>> reshaped[:,0,:,:]
tensor([[[0.6830, 0.2091],
         [0.8786, 0.6002]],

        [[0.7604, 0.2658],
         [0.3518, 0.1969]],

        [[0.3787, 0.0042],
         [0.3481, 0.2722]],

        [[0.4175, 0.8700],
         [0.1930, 0.7646]]])

I read that the fold/unfold functions may be able to help by creating a sliding window.


Hi,

Yes, it works. In your case you are extracting patches with a sliding window and then permuting the channels.

Here is the code that will do the job for your arbitrary example:

x.unfold(2, 2, 2)[0].unfold(2, 2, 2).contiguous().view(3, -1, 2, 2).permute((1, 0, 2, 3))
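
Broken into steps, the intermediate shapes look like this (a sketch for this 1×3×4×4 example):

x = torch.rand(1, 3, 4, 4)
a = x.unfold(2, 2, 2)                  # (1, 3, 2, 4, 2): two windows along height
b = a[0]                               # (3, 2, 4, 2): batch dim dropped
c = b.unfold(2, 2, 2)                  # (3, 2, 2, 2, 2): two windows along width
d = c.contiguous().view(3, -1, 2, 2)   # (3, 4, 2, 2): flatten the windows per channel
out = d.permute((1, 0, 2, 3))          # (4, 3, 2, 2): patches become the batch dim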

If you have questions about how it works, I have explained the calculations in another post.

Bests


Thank you @Nikronic!
If I understand this correctly, then the img_to_batch function I wrote above simply becomes:

def img_to_batch(imgtensor, patch_size: int):
    _, ch, height, width = imgtensor.shape
    assert height % patch_size == 0 and width % patch_size == 0, (
        'img_to_batch: dims must be divisible by patch_size. {}%{}!=0'.format(
            imgtensor.shape, patch_size))
    return imgtensor.unfold(2, patch_size, patch_size).unfold(
        3, patch_size, patch_size).contiguous().view(
            ch, -1, patch_size, patch_size).permute((1, 0, 2, 3))

Note that I made a slight modification: instead of the equivalent of "x.unfold(2, 2, 2)[0].unfold(2," I use "x.unfold(2, 2, 2).unfold(3,". I think it can then also handle an arbitrary batch size (and, just like the channels, the batches wouldn't get mixed up). Correct me if I'm wrong.

I do not think you can do this: depending on the input, tensor.unfold creates another dimension (like .unsqueeze(0) would), and because of that the .view and .permute calls will no longer be valid due to a mismatch in the number of dimensions.

Actually, I have never thought of generalizing the fold and unfold methods because of their tricky behavior (or at least my poor understanding of it).
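
For example, keeping the batch dimension, the two unfold calls leave you with a six-dimensional tensor (a quick shape check):

>>> x = torch.rand(4, 3, 8, 8)
>>> x.unfold(2, 4, 4).shape
torch.Size([4, 3, 2, 8, 4])
>>> x.unfold(2, 4, 4).unfold(3, 4, 4).shape
torch.Size([4, 3, 2, 2, 4, 4])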

You are right; the following test shows that data from channel 1 can end up in channel 0:

>>> img = torch.rand(4,3,8,8)
>>> img_to_batch(img,4)[4,0]
tensor([[0.4258, 0.3276, 0.8221, 0.6588],
        [0.5438, 0.9239, 0.0490, 0.7193],
        [0.5852, 0.7115, 0.0703, 0.1770],
        [0.4305, 0.4190, 0.2891, 0.0326]])
>>> img[0,1]
tensor([[0.4258, 0.3276, 0.8221, 0.6588, 0.9918, 0.6219, 0.4951, 0.4356],
        [0.5438, 0.9239, 0.0490, 0.7193, 0.6819, 0.0627, 0.0361, 0.3178],
        [0.5852, 0.7115, 0.0703, 0.1770, 0.3855, 0.0666, 0.7337, 0.0240],
        [0.4305, 0.4190, 0.2891, 0.0326, 0.3457, 0.7378, 0.5640, 0.7104],
        [0.3787, 0.2371, 0.4585, 0.6150, 0.7169, 0.6518, 0.4671, 0.1212],
        [0.8061, 0.4295, 0.1194, 0.7166, 0.7526, 0.8067, 0.1612, 0.2812],
        [0.3896, 0.8208, 0.5835, 0.6830, 0.0191, 0.7138, 0.9124, 0.7285],
        [0.0963, 0.4236, 0.2779, 0.8006, 0.1528, 0.5168, 0.6543, 0.7928]])

I think I got it 🙂 I moved the channel axis to the front so that it would be safe, and it passes my manual tests.

def img_to_batch(img, patch_size: int):
    assert img.dim() == 4
    _, ch, height, width = img.shape
    assert height % patch_size == 0 and width % patch_size == 0, (
        'img_to_batch: dims must be divisible by patch_size. {}%{}!=0'.format(
            img.shape, patch_size))
    # Pull the channel axis in front of the batch axis so the flatten in
    # reshape only mixes batch and patch indices, never channels.
    return img.unfold(2, patch_size, patch_size).unfold(
        3, patch_size, patch_size).transpose(1, 0).reshape(
            ch, -1, patch_size, patch_size).transpose(1, 0)