How to extract smaller image patches (3D)?

It depends a bit on how you would like to feed these data into your model(s).
If you want to create a batch of PET images only and another of CT images, you could use two DataLoaders. On the other hand, if you want to mix the modalities, you could use a single DataLoader.
Could you explain your use case a bit, i.e. how are the modalities used and in which steps?

Thank you very much for your suggestion. I’ve tried to fold the patches back through a loop of nn.Fold. It nearly works, however, there still exist some problems which I cannot figure out. Would you please take a look and find the problem?

def show_tensor(tensor):
    """Display a channels-first (C x H x W) image in an OpenCV window until a key is pressed.

    Args:
        tensor: image with channels first, e.g. one depth slice of a folded
            volume. torch tensors are detached and moved to the CPU first.
            Values are clipped to [0, 255] before the uint8 cast, because
            casting out-of-range floats to uint8 is undefined in NumPy.
    """
    if hasattr(tensor, "detach"):
        # torch tensor: drop autograd history and move to host memory for NumPy
        tensor = tensor.detach().cpu()
    img_array = np.clip(np.asarray(tensor), 0, 255).astype(np.uint8)
    img_array = img_array.transpose(1, 2, 0)  # C x H x W -> H x W x C as OpenCV expects
    cv2.imshow('img', img_array)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

def fold_3d_official(img_path='/home/katou2/Pictures/your_name_resize.png'):
    """Cut a synthetic 3D volume into 3x3x3 patches and rebuild it with two nn.Fold passes.

    The image at ``img_path`` is resized to 112x112 and repeated 16 times
    along a new depth axis, giving a (1, C, 16, 112, 112) volume. It is
    zero-padded by 1 on every side of D, H and W and cut into
    non-overlapping 3x3x3 patches with ``Tensor.unfold``. It is then
    reassembled:

    1. ``fold_1`` folds the depth axis back (kernel (kD, 1)).
    2. ``fold_2`` folds height/width back (kernel (kH, kW)).

    This only works if the patches are laid out as
    (B, C, kD, d_out, kH, kW, h_out, w_out) before flattening. The
    permutation (0, 1, 5, 2, 6, 7, 3, 4) produces exactly that layout.
    Because stride == kernel size, every voxel is covered exactly once and
    the padding falls outside the fold output, so the rebuilt volume equals
    the input. Depth slice 1 is shown with ``show_tensor``.

    Args:
        img_path: path of the image to load with OpenCV.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    kernel_size = (3, 3, 3)
    stride = (3, 3, 3)
    padding = (1, 1, 1)
    depth, size = 16, 112

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(img_path)
    img = np.array(img, dtype=np.float32)
    img = cv2.resize(img, (size, size))
    img = torch.from_numpy(img)  # (H, W, C)

    # Repeat the image along depth: (D, H, W, C) -> (1, C, D, H, W)
    img_batch = torch.stack([img] * depth)
    img_batch = img_batch.permute(3, 0, 1, 2).unsqueeze(0)
    channels = img_batch.shape[1]

    # F.pad lists the last dimension first: (W_l, W_r, H_t, H_b, D_f, D_b)
    x = F.pad(img_batch, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]))
    x = (x.unfold(2, kernel_size[0], stride[0])
          .unfold(3, kernel_size[1], stride[1])
          .unfold(4, kernel_size[2], stride[2]))
    # (B, C, d_out, h_out, w_out, kD, kH, kW), here (1, 3, 6, 38, 38, 3, 3, 3)
    h_out, w_out = x.shape[3], x.shape[4]
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kD, d_out, kH, kW, h_out, w_out), here (1, 3, 3, 6, 3, 3, 38, 38)

    # Fold depth: C * kD values per block, d_out blocks along D, each
    # followed by kH * kW * h_out * w_out columns treated as a flat width.
    x = x.contiguous().view(1, channels * kernel_size[0], -1)
    inner = kernel_size[1] * kernel_size[2] * h_out * w_out
    fold_1 = nn.Fold((depth, inner), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1))
    y = fold_1(x)  # (B, C, D, kH * kW * h_out * w_out)

    # Fold height/width: C * D * kH * kW values per block, h_out * w_out blocks.
    y = y.contiguous().view(1, channels * depth * kernel_size[1] * kernel_size[2], -1)
    fold_2 = nn.Fold((size, size), kernel_size=kernel_size[1:], padding=padding[1:], stride=stride[1:])
    z = fold_2(y)  # (B, C * D, H, W)
    z = z.contiguous().view(1, channels, depth, size, size)
    show_tensor(z[0, :, 1, :, :])

Do you get an error message (and could post it here) or what is not working at the moment?

The code rebuilds an image via nn.Fold from patches extracted with your method. It runs without an error message and an image is rebuilt, but the rebuilt image is slightly different from the original image, and I cannot figure out why.

Could you post the shape of the input tensor (after the permutation etc.) so that we could have a look?

If you are interested, you can run the following code with an input image path and you will see the difference between the rebuilt image and the original image. Thanks.

import cv2
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def show_tensor(tensor):
    """Display a channels-first (C x H x W) image in an OpenCV window until a key is pressed.

    Args:
        tensor: image with channels first, e.g. one depth slice of a folded
            volume. torch tensors are detached and moved to the CPU first.
            Values are clipped to [0, 255] before the uint8 cast, because
            casting out-of-range floats to uint8 is undefined in NumPy.
    """
    if hasattr(tensor, "detach"):
        # torch tensor: drop autograd history and move to host memory for NumPy
        tensor = tensor.detach().cpu()
    img_array = np.clip(np.asarray(tensor), 0, 255).astype(np.uint8)
    img_array = img_array.transpose(1, 2, 0)  # C x H x W -> H x W x C as OpenCV expects
    cv2.imshow('img', img_array)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def fold_3d_official(img_path):
    """Cut a synthetic 3D volume into 3x3x3 patches and rebuild it with two nn.Fold passes.

    The image at ``img_path`` is resized to 112x112 and repeated 16 times
    along a new depth axis, giving a (1, C, 16, 112, 112) volume. It is
    zero-padded by 1 on every side of D, H and W and cut into
    non-overlapping 3x3x3 patches with ``Tensor.unfold``. It is then
    reassembled:

    1. ``fold_1`` folds the depth axis back (kernel (kD, 1)).
    2. ``fold_2`` folds height/width back (kernel (kH, kW)).

    This only works if the patches are laid out as
    (B, C, kD, d_out, kH, kW, h_out, w_out) before flattening. The
    permutation (0, 1, 5, 2, 6, 7, 3, 4) produces exactly that layout.
    Because stride == kernel size, every voxel is covered exactly once and
    the padding falls outside the fold output, so the rebuilt volume equals
    the input. Depth slice 1 is shown with ``show_tensor``.

    Args:
        img_path: path of the image to load with OpenCV.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    kernel_size = (3, 3, 3)
    stride = (3, 3, 3)
    padding = (1, 1, 1)
    depth, size = 16, 112

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(img_path)
    img = np.array(img, dtype=np.float32)
    img = cv2.resize(img, (size, size))
    img = torch.from_numpy(img)  # (H, W, C)

    # Repeat the image along depth: (D, H, W, C) -> (1, C, D, H, W)
    img_batch = torch.stack([img] * depth)
    img_batch = img_batch.permute(3, 0, 1, 2).unsqueeze(0)
    channels = img_batch.shape[1]

    # F.pad lists the last dimension first: (W_l, W_r, H_t, H_b, D_f, D_b)
    x = F.pad(img_batch, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]))
    x = (x.unfold(2, kernel_size[0], stride[0])
          .unfold(3, kernel_size[1], stride[1])
          .unfold(4, kernel_size[2], stride[2]))
    # (B, C, d_out, h_out, w_out, kD, kH, kW), here (1, 3, 6, 38, 38, 3, 3, 3)
    h_out, w_out = x.shape[3], x.shape[4]
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kD, d_out, kH, kW, h_out, w_out), here (1, 3, 3, 6, 3, 3, 38, 38)

    # Fold depth: C * kD values per block, d_out blocks along D, each
    # followed by kH * kW * h_out * w_out columns treated as a flat width.
    x = x.contiguous().view(1, channels * kernel_size[0], -1)
    inner = kernel_size[1] * kernel_size[2] * h_out * w_out
    fold_1 = nn.Fold((depth, inner), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1))
    y = fold_1(x)  # (B, C, D, kH * kW * h_out * w_out)

    # Fold height/width: C * D * kH * kW values per block, h_out * w_out blocks.
    y = y.contiguous().view(1, channels * depth * kernel_size[1] * kernel_size[2], -1)
    fold_2 = nn.Fold((size, size), kernel_size=kernel_size[1:], padding=padding[1:], stride=stride[1:])
    z = fold_2(y)  # (B, C * D, H, W)
    z = z.contiguous().view(1, channels, depth, size, size)
    show_tensor(z[0, :, 1, :, :])


if __name__ == "__main__":
    import sys

    # Take the image path from the command line if given; otherwise fall
    # back to the path used in the original post.
    img_path = sys.argv[1] if len(sys.argv) > 1 else '/home/katou2/Pictures/your_name_resize.png'
    fold_3d_official(img_path)

I finally got it working. I only had to change the permutation x = x.permute(0, 1, 5, 6, 7, 2, 3, 4) to x = x.permute(0, 1, 5, 2, 6, 7, 3, 4). Anyone who wants to fold 3D patches back can use the following code.

import cv2
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def show_tensor(tensor):
    """Display a channels-first (C x H x W) image in an OpenCV window until a key is pressed.

    Args:
        tensor: image with channels first, e.g. one depth slice of a folded
            volume. torch tensors are detached and moved to the CPU first.
            Values are clipped to [0, 255] before the uint8 cast, because
            casting out-of-range floats to uint8 is undefined in NumPy.
    """
    if hasattr(tensor, "detach"):
        # torch tensor: drop autograd history and move to host memory for NumPy
        tensor = tensor.detach().cpu()
    img_array = np.clip(np.asarray(tensor), 0, 255).astype(np.uint8)
    img_array = img_array.transpose(1, 2, 0)  # C x H x W -> H x W x C as OpenCV expects
    cv2.imshow('img', img_array)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def fold_3d_official(img_path):
    """Cut a synthetic 3D volume into 3x3x3 patches and rebuild it with two nn.Fold passes.

    The image at ``img_path`` is resized to 112x112 and repeated 16 times
    along a new depth axis, giving a (1, C, 16, 112, 112) volume. It is
    zero-padded by 1 on every side of D, H and W and cut into
    non-overlapping 3x3x3 patches with ``Tensor.unfold``. It is then
    reassembled:

    1. ``fold_1`` folds the depth axis back (kernel (kD, 1)).
    2. ``fold_2`` folds height/width back (kernel (kH, kW)).

    This only works if the patches are laid out as
    (B, C, kD, d_out, kH, kW, h_out, w_out) before flattening. The
    permutation (0, 1, 5, 2, 6, 7, 3, 4) produces exactly that layout.
    Because stride == kernel size, every voxel is covered exactly once and
    the padding falls outside the fold output, so the rebuilt volume equals
    the input. Depth slice 1 is shown with ``show_tensor``.

    Args:
        img_path: path of the image to load with OpenCV.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    kernel_size = (3, 3, 3)
    stride = (3, 3, 3)
    padding = (1, 1, 1)
    depth, size = 16, 112

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(img_path)
    img = np.array(img, dtype=np.float32)
    img = cv2.resize(img, (size, size))
    img = torch.from_numpy(img)  # (H, W, C)

    # Repeat the image along depth: (D, H, W, C) -> (1, C, D, H, W)
    img_batch = torch.stack([img] * depth)
    img_batch = img_batch.permute(3, 0, 1, 2).unsqueeze(0)
    channels = img_batch.shape[1]

    # F.pad lists the last dimension first: (W_l, W_r, H_t, H_b, D_f, D_b)
    x = F.pad(img_batch, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]))
    x = (x.unfold(2, kernel_size[0], stride[0])
          .unfold(3, kernel_size[1], stride[1])
          .unfold(4, kernel_size[2], stride[2]))
    # (B, C, d_out, h_out, w_out, kD, kH, kW), here (1, 3, 6, 38, 38, 3, 3, 3)
    h_out, w_out = x.shape[3], x.shape[4]
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kD, d_out, kH, kW, h_out, w_out), here (1, 3, 3, 6, 3, 3, 38, 38)

    # Fold depth: C * kD values per block, d_out blocks along D, each
    # followed by kH * kW * h_out * w_out columns treated as a flat width.
    x = x.contiguous().view(1, channels * kernel_size[0], -1)
    inner = kernel_size[1] * kernel_size[2] * h_out * w_out
    fold_1 = nn.Fold((depth, inner), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1))
    y = fold_1(x)  # (B, C, D, kH * kW * h_out * w_out)

    # Fold height/width: C * D * kH * kW values per block, h_out * w_out blocks.
    y = y.contiguous().view(1, channels * depth * kernel_size[1] * kernel_size[2], -1)
    fold_2 = nn.Fold((size, size), kernel_size=kernel_size[1:], padding=padding[1:], stride=stride[1:])
    z = fold_2(y)  # (B, C * D, H, W)
    z = z.contiguous().view(1, channels, depth, size, size)
    show_tensor(z[0, :, 1, :, :])


if __name__ == "__main__":
    import sys

    # Take the image path from the command line if given; otherwise fall
    # back to the path used in the original post.
    img_path = sys.argv[1] if len(sys.argv) > 1 else '/home/katou2/Pictures/your_name_resize.png'
    fold_3d_official(img_path)

@ptrblck I appreciate what you have done indeed. Thank you very much. It indeed helps a lot.

1 Like

Hello, regarding this particular example, just for my understanding:

Suppose I have an image of size [284,143,143] converted to NumPy, with dtype=float32 and ndim=3,
and I want to extract 2D patches from it. How would the sliding-window padding and the symmetric padding change here?

I mean, the padding would be computed as in the example, and then for unfold we would just do
ret = x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride) and not use size[2] here?
If that's the case, it would give torch.Size([5656, 64, 64]) 2D patches. Is that correct?

If we do not pad for the size[2] then ret = x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride) , it will give torch.Size([3030, 64, 64])
Can you explain how the padding affects the result here, and how I can achieve 64 x 64 2D patches from [284,143,143] in this exact scenario?

I’m not sure why the padding should change if you are using numpy. Could you explain this question a bit, please?

It depends, what the dimensions represent in your example.
If I remember the original question correctly, all 3 dimensions formed a volume, so all of them had to be padded.
Usually you would pad the spatial dimensions for 2D patches.

Could you explain, what the dimensions stand for? I assume dim0 would be the channel dimension, since you would like to create 2D patches?
If so, you could reuse my code for dim1 and dim2 padding.

Hello ,
Sorry for the confusion. My question is:
If I have a medical image stored as a NumPy array of size [284,143,143], where 284 is the number of slices, followed by H and W (dtype = float32, ndim = 3),
and I want to extract 2D patches of size 64 x 64 from it, how can I achieve that using unfold?

def extract_patches(img, kernel_size=64, stride=46):
    """Cut every slice of an (S, H, W) volume into square 2D patches.

    H and W are zero-padded so that a window of ``kernel_size`` moved in
    steps of ``stride`` reaches past the border. The padding is split as
    evenly as possible, with the extra pixel of an odd total going to the
    bottom/right. Only dims 1 and 2 are unfolded, so each patch stays
    inside a single slice.

    Args:
        img: 3D tensor (S, H, W); dim 0 (e.g. slices) is not unfolded.
        kernel_size: side length of the square patches.
        stride: step between neighbouring patch origins.

    Returns:
        Tensor of shape (S * n_h * n_w, 1, kernel_size, kernel_size), where
        n_h / n_w are the numbers of windows along H / W.
    """
    # Total padding so that the last window ends at or beyond the border.
    pad1 = (img.size(1) // stride * stride + kernel_size) - img.size(1)
    pad2 = (img.size(2) // stride * stride + kernel_size) - img.size(2)

    # Symmetric padding: left gets floor(pad / 2), right gets the rest.
    pad1_left, pad1_right = pad1 // 2, pad1 - pad1 // 2
    pad2_left, pad2_right = pad2 // 2, pad2 - pad2 // 2

    # F.pad lists the last dimension first: (W_left, W_right, H_top, H_bottom)
    x = F.pad(img, (pad2_left, pad2_right, pad1_left, pad1_right))

    ret = x.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
    ret = ret.reshape(-1, kernel_size, kernel_size)  # (S * n_h * n_w, k, k)
    return ret.unsqueeze(1)  # add a channel dimension

My doubt is will I be performing unfolding on dim 0 as well?

It depends what each patch should contain.
If each 64x64 patch should contain all slices, then you should not unfold dim0.
On the other hand, if each patch should only contain a specific number of slices, you could also unfold dim0. If the patches should contain a single slice, it would probably be easier to just split the output.

Hello,

Can you explain what you mean by splitting the output?

To split dim0 into separate tensors, each with a size of 1, you could use x.split(1, dim=0)
This could be applied after unfolding dim1 and dim2, but it depends on your use case and what the result shape should be.

Hi @ptrblck , I have a similar problem: I also have H x W x D sized 3D images (no batch or channel dimension). And I wish to extract 3x3x3 patches, but overlapping.
I would need to first pad the image everywhere with infinity (because of the subsequent computations), and then do the unfolding, so that the number of patches in the end equals the number of voxels in my original image.
But I would need the output to be of size 27 x H x W x D, so that output[:, i, j, k] is the flattened 3x3x3 patch centered in the original image around [i, j, k].

Can I do this directly with unfold as well? I suppose I could use reshape on top of it, but I have no clue on how to make sure that the values really end up on the right places after reshaping…

I also checked the view method (I assume that would be faster, considering it’s just another view of the same tensor?) but as far as I understood, view only works for non-overlapping patches… Or did I misunderstand?

I’m not sure I understand the desired output shape correctly.
unfold would create a specific number of patches, where each patch would have the specified kernel shape.
It seems your output should have some H, W, D dimensions, which might not be the kernel shape?
If so, how would these shapes be calculated?

I assume you would like to somehow reshape the dimension containing the patches, so that you could index neighboring patches using i, j, k?

So, if the original image IM size is H x W x D, I pad it all around with 1 pixel of Inf, to get something, let’s call it IM2, sized (H+2) x (W+2) x (D+2).
Then I would like to extract all the 3x3x3 patches from there, centered on the pixels of the original image (so the first patch would span IM2[0:3, 0:3, 0:3], and its center corresponds to the first pixel in IM, IM[0,0,0]).

If I do three unfolds on IM2 one after another, as you have suggested in some answers above, I get something sized H x W x D x 3 x 3 x 3. What I wish is to have the patch dimensions flattened instead, to get H x W x D x 27.

So what I wonder is if the three unfolds + reshape would really give me exactly this, with values at the right places? That is, if the patch centered at i,j,k is [[1,2,3],[4,5,6],[7,8,9]], then in the final H x W x D x 27 array the elements [i,j,k,:] should be [1,2,3,4,5,6,7,8,9].
In addition - is it possible to do this entire thing, or at least some parts, by using .view instead? Because I am working with very large 3D arrays, so reshaping and in any way copying the data should probably be avoided… But I just don’t see how to use view in this sense (to have overlapping patches). And also, if I just use .view(h,w,d,27) on the unfolded array of size HxWxDx3x3x3, it complains that the thing is not contiguous… So I am a bit lost at how to do it efficiently.

This should be the case and this quick test shows the behavior:

# Overlapping 3x3x3 patches (stride 1) of a 24^3 volume, one flattened patch per position.
x = torch.arange(24 * 24 * 24).view(24, 24, 24)

kernel_size = 3
stride = 1
patches = x
for dim in range(3):
    patches = patches.unfold(dim, kernel_size, stride)
# (22, 22, 22, 3, 3, 3) -> (22, 22, 22, 27); the windows overlap in memory,
# so flattening the last three dims has to copy the data.
patches = patches.reshape(*patches.shape[:3], -1)
print(patches[0, 0, 0])
print(x[:3, :3, :3])

The unfold op should already be using views without copies, which can be seen in the strides of patches:

# Strides of the raw unfold result, before any copy. The last three dims
# re-walk the same memory as the first three, so no data is duplicated.
# (The reshaped `patches` tensor is a contiguous copy whose strides would be
# (13068, 594, 27, 1) instead.)
print(x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride).unfold(2, kernel_size, stride).stride())
> (576, 24, 1, 576, 24, 1)

If any operation requires contiguous memory, the mentioned error will be raised and you would have to trigger the copy via contiguous().
So while you could create the patches using views only, you won’t be able to perform all operations on these patches, if the memory locations overlap and you have to copy the tensor.

Hello,

I’m trying to figure out how to apply your code to overlapping 3D patches, but can’t get it to work. Do you have any idea how the code needs to be changed to work for this example?

Hi all! You can use samplers in TorchIO for all this stuff. You can extract 2D, 3D or 4D patches from medical images randomly (for training) or densely (for testing). Here’s a little snippet: Creating non overlapping patches and reconstructing image back from the patches

There’s also support for overlapping patches.

Hi,
I have created a patch extractor and combiner using the torch.nn.functional API.

code is below,
Note that extract_patches_3d and extract_patches_3ds produce the same output; the latter is just shorter.

Also note that when combining patches that overlap, the overlapping elements will be summed.

import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
    
    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)                                                     
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))                   
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)                                   
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))        
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  

    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x



def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x

# Round trip: cut a (B=2, C=2, D=2, H=4, W=4) volume into 2x2x2 patches and rebuild it.
kernel, pad, step = 2, 1, 2
a = torch.arange(1, 129, dtype=torch.float).view(2, 2, 2, 4, 4)
print(a.shape)
print(a)

# extract_patches_3d(a, kernel, padding=pad, stride=step) yields the same patches.
b = extract_patches_3ds(a, kernel, padding=pad, stride=step)
print(b.shape)
print(b)

c = combine_patches_3d(b, kernel, tuple(a.shape), padding=pad, stride=step)
print(c.shape)
print(c)
print(torch.all(a == c))

Output:

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
torch.Size([36, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   1.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  2.,   3.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  4.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   5.],
           [  0.,   9.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  6.,   7.],
           [ 10.,  11.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  8.,   0.],
           [ 12.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  13.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 14.,  15.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 16.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  17.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 18.,  19.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 20.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  21.],
           [  0.,  25.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 22.,  23.],
           [ 26.,  27.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 24.,   0.],
           [ 28.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  29.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 30.,  31.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 32.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  33.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 34.,  35.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 36.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  37.],
           [  0.,  41.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 38.,  39.],
           [ 42.,  43.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 40.,   0.],
           [ 44.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  45.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 46.,  47.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 48.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  49.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 50.,  51.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 52.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  53.],
           [  0.,  57.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 54.,  55.],
           [ 58.,  59.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 56.,   0.],
           [ 60.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  61.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 62.,  63.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 64.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  65.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 66.,  67.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 68.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  69.],
           [  0.,  73.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 70.,  71.],
           [ 74.,  75.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 72.,   0.],
           [ 76.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  77.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 78.,  79.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 80.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  81.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 82.,  83.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 84.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  85.],
           [  0.,  89.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 86.,  87.],
           [ 90.,  91.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 88.,   0.],
           [ 92.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  93.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 94.,  95.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 96.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  97.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 98.,  99.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [100.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 101.],
           [  0., 105.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[102., 103.],
           [106., 107.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[104.,   0.],
           [108.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 109.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[110., 111.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[112.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0., 113.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [114., 115.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [116.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0., 117.],
           [  0., 121.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[118., 119.],
           [122., 123.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[120.,   0.],
           [124.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0., 125.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[126., 127.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[128.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
tensor(True)