Extract image patches from image

Disclaimer: I know there are already a couple of questions like this, but none of them seems to cover my problem.

My goal is simply to extract image patches. The following should make it clear. Consider the following sketch.

I want to separate my image into, e.g., 4 patches, with each patch containing all the channels.
For example, given a 3x8x8 image, I want the patches tensor to have shape 4x(4 * 4 * 3) → 4x48.

Any ideas for generally applicable code?

tensor.unfold should work as seen in this minimal example:

import torch

x = torch.randn(3, 8, 8)
out = x.unfold(2, 4, 4).unfold(1, 4, 4)
print(out.shape)
# torch.Size([3, 2, 2, 4, 4])
out = out.permute(1, 2, 0, 3, 4)
print(out.shape)
# torch.Size([2, 2, 3, 4, 4])
out = out.contiguous().view(out.size(0)*out.size(1), -1)
print(out.shape)
# torch.Size([4, 48])

Thanks for the quick reply. Although it gets the shapes right, it doesn't get the ordering of the values within each patch right.
Consider this code:

>>> x = torch.arange(3* 8* 8).reshape((3,8,8))
>>> p = 4
>>> res = []
>>> for i in range(1, 3):
...     for j in range(1,3):
...         res.append(x[:, p*(i-1):p*i, p*(j-1):p*j].flatten())
>>> res[0]
tensor([  0,   1,   2,   3,   8,   9,  10,  11,  16,  17,  18,  19,  24,  25,
         26,  27,  64,  65,  66,  67,  72,  73,  74,  75,  80,  81,  82,  83,
         88,  89,  90,  91, 128, 129, 130, 131, 136, 137, 138, 139, 144, 145,
        146, 147, 152, 153, 154, 155])

Whereas in your case it looks like:

>>> out[0]
tensor([  0,   8,  16,  24,   1,   9,  17,  25,   2,  10,  18,  26,   3,  11,
         19,  27,  64,  72,  80,  88,  65,  73,  81,  89,  66,  74,  82,  90,
         67,  75,  83,  91, 128, 136, 144, 152, 129, 137, 145, 153, 130, 138,
        146, 154, 131, 139, 147, 155])

Any idea how to fix this?
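The difference in ordering seems to come from how tensor.unfold works: each call appends the window as a new last dimension. Unfolding the width (dim 2) first and the height (dim 1) second therefore leaves the trailing dimensions as (width, height), so every patch comes out transposed. A minimal check, assuming the same 3x8x8 example (the names w_then_h, h_then_w and ref are just for illustration):

```python
import torch

x = torch.arange(3 * 8 * 8).reshape(3, 8, 8)

# Width (dim 2) unfolded first, height (dim 1) second:
# the trailing dims end up as (width, height).
w_then_h = x.unfold(2, 4, 4).unfold(1, 4, 4)  # shape (3, 2, 2, 4, 4)

# Height (dim 1) unfolded first, width (dim 2) second:
# the trailing dims end up as (height, width).
h_then_w = x.unfold(1, 4, 4).unfold(2, 4, 4)  # shape (3, 2, 2, 4, 4)

# Top-left patch of channel 0, taken directly by slicing.
ref = x[0, :4, :4]

print(torch.equal(w_then_h[0, 0, 0], ref.T))  # True: patch is transposed
print(torch.equal(h_then_w[0, 0, 0], ref))    # True: patch is in row-major order
```

So either swapping the order of the two unfold calls or transposing the last two dimensions afterwards should give the row-major ordering of the loop version.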

I’m not sure you would need to extract patches if your expected output is just the flattened tensor.
In my approach the first intermediate output would keep the patches as:

x = torch.arange(3* 8* 8).reshape((3,8,8))
out = x.unfold(2, 4, 4).unfold(1, 4, 4)
print(out.shape)
# torch.Size([3, 2, 2, 4, 4])
out = out.permute(1, 2, 0, 3, 4)
print(out)
tensor([[[[[  0,   8,  16,  24],
           [  1,   9,  17,  25],
           [  2,  10,  18,  26],
           [  3,  11,  19,  27]],

          [[ 64,  72,  80,  88],
           [ 65,  73,  81,  89],
           [ 66,  74,  82,  90],
           [ 67,  75,  83,  91]],
           ...

If you want to get the flattened output wouldn’t a view operation just work?

Yes, I guess a view() would work. Could you help with that? I don’t seem to get my desired output :(

I think I would need to see your (slow) loop approach or some other reference.
The output from your previous post:

>>> res[0]
tensor([  0,   1,   2,   3,   8,   9,  10,  11,  16,  17,  18,  19,  24,  25,
         26,  27,  64,  65,  66,  67,  72,  73,  74,  75,  80,  81,  82,  83,
         88,  89,  90,  91, 128, 129, 130, 131, 136, 137, 138, 139, 144, 145,
        146, 147, 152, 153, 154, 155])

would just work, if you use:

x = torch.arange(3* 8* 8).reshape((3,8,8))
out = x.view(-1)
print(out)

although your output is missing some values.

I misunderstood your question "If you want to get the flattened output wouldn’t a view operation just work?"
My intention is to first extract the patches and then flatten each patch. The first code you sent was already almost perfect for my needs; it just messed up the order of the values within each patch.

Update:
I guess I found the missing modification to your code:

>>> x = torch.arange(3* 8* 8).reshape((3,8,8))
>>> out = x.unfold(2, 4, 4).unfold(1, 4, 4)
>>> out = torch.transpose(out, 3, 4)  # transposing the last two dims gives me my desired output
>>> out = out.permute(1, 2, 0, 3, 4)
>>> out = out.contiguous().view(out.size(0)*out.size(1), -1)
>>> out[0]
tensor([  0,   1,   2,   3,   8,   9,  10,  11,  16,  17,  18,  19,  24,  25,
         26,  27,  64,  65,  66,  67,  72,  73,  74,  75,  80,  81,  82,  83,
         88,  89,  90,  91, 128, 129, 130, 131, 136, 137, 138, 139, 144, 145,
        146, 147, 152, 153, 154, 155])

which matches the output of the code I posted earlier:

>>> res[0]
tensor([  0,   1,   2,   3,   8,   9,  10,  11,  16,  17,  18,  19,  24,  25,
         26,  27,  64,  65,  66,  67,  72,  73,  74,  75,  80,  81,  82,  83,
         88,  89,  90,  91, 128, 129, 130, 131, 136, 137, 138, 139, 144, 145,
        146, 147, 152, 153, 154, 155])
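For a more generally applicable version, the steps above could be wrapped in a small helper. This is only a sketch: the name extract_patches and its signature are made up for illustration, and it assumes a (C, H, W) input whose height and width are divisible by the patch size. It unfolds the height before the width, which should make the extra transpose unnecessary.

```python
import torch


def extract_patches(x: torch.Tensor, p: int) -> torch.Tensor:
    """Split a (C, H, W) tensor into non-overlapping p x p patches.

    Returns a tensor of shape (H//p * W//p, C * p * p), where each row is one
    patch flattened in (channel, row, column) order.
    Hypothetical helper, not part of PyTorch.
    """
    c, h, w = x.shape
    if h % p != 0 or w % p != 0:
        raise ValueError("height and width must be divisible by the patch size")
    # Unfold height first, then width: (C, H//p, W//p, p, p),
    # with the trailing dims in (row, column) order.
    patches = x.unfold(1, p, p).unfold(2, p, p)
    # Move the patch-grid dims to the front: (H//p, W//p, C, p, p).
    patches = patches.permute(1, 2, 0, 3, 4)
    # reshape copies if needed, so no explicit contiguous() call is required.
    return patches.reshape(-1, c * p * p)


x = torch.arange(3 * 8 * 8).reshape(3, 8, 8)
out = extract_patches(x, 4)
print(out.shape)  # torch.Size([4, 48])
```

Here out[0] should match res[0] from the loop version, and the same helper should work for other channel counts and patch sizes as long as the divisibility assumption holds.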

Oh sorry, I’m blind. I didn’t notice that your reference output contained a different ordering.
Good to hear you’ve solved it and sorry for the confusion.
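As a possible alternative to tensor.unfold, torch.nn.functional.unfold extracts sliding blocks from a batched input and flattens each block in (channel, row, column) order. A minimal sketch, assuming the same 3x8x8 example, non-overlapping patches (stride equal to the kernel size), and a floating-point input, since this function may not support integer dtypes:

```python
import torch
import torch.nn.functional as F

p = 4
# Float input, with a batch dimension added in front: (1, 3, 8, 8).
x = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)

# Each column of the result is one flattened patch: (1, C*p*p, num_patches).
cols = F.unfold(x.unsqueeze(0), kernel_size=p, stride=p)

# Drop the batch dim and put the patches on the first axis: (num_patches, C*p*p).
patches = cols.squeeze(0).transpose(0, 1)
print(patches.shape)  # torch.Size([4, 48])
```

The rows of patches should follow the same ordering as the loop-based reference, just as floats rather than integers.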


No problem holy PyTorch Keanu Reeves!