How to divide a tensor of size MxN into smaller tensors

I am interested in dividing a tensor into subtensors.
E.g., if I have a tensor A = torch.rand(2,3,8,8), I want to generate a tensor B of size 2x3x16x2x2 and then convert it back so that it has the same size as A.

In this case, each spatial dimension of tensor A is divided by k = 4, so we choose a subblock size of s = 8//k = 2 (2x2 blocks).

In other words, the subblocks are chosen in the following way, where the same color means those elements belong to the same subblock:

I'm not sure how to do it in a fast way (the brute-force way is using loops). I think fold/unfold should be faster.

Can anyone help me code this in a fast way?

You can reshape tensor A:

>>> A = torch.rand(2,3,8,8)
>>> B = A.view(2, 3, 16, 2, 2)
>>> B.shape
torch.Size([2, 3, 16, 2, 2])

This won't do the reshaping the way I want. The view function reshapes row-wise.
Here is how it looks if we do it the way you suggest:
[image]

But I need the first subblock to be like this:


[[0.803,0.6546],
[0.2339,0.9017]]
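
To make the difference concrete, here is a small sketch comparing what `view` returns with the block that is wanted. The integer-valued tensor is an assumption chosen for readability; it is not the random tensor from the posts above.

```python
import torch

# Illustrative tensor: entry (r, c) of the 8x8 image holds the value 8*r + c.
A = torch.arange(64).reshape(1, 1, 8, 8)

# view() keeps the row-major memory order, so the first "block" is just
# the first four entries of row 0.
B_view = A.view(1, 1, 16, 2, 2)
first_view_block = B_view[0, 0, 0]   # tensor([[0, 1], [2, 3]])

# The block that is actually wanted is the top-left 2x2 window of the image.
wanted_block = A[0, 0, :2, :2]       # tensor([[0, 1], [8, 9]])
```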

I have a solution, but maybe not the most elegant one. First, let’s work with the following tensor for A to make it easier:

>>> import numpy as np
>>> import torch
>>> A = torch.tensor(np.arange(2*3*64)/10.0).reshape(2, 3, 8, 8)
>>> A.shape
torch.Size([2, 3, 8, 8])
>>> A[0, 0]
tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
        [0.8000, 0.9000, 1.0000, 1.1000, 1.2000, 1.3000, 1.4000, 1.5000],
        [1.6000, 1.7000, 1.8000, 1.9000, 2.0000, 2.1000, 2.2000, 2.3000],
        [2.4000, 2.5000, 2.6000, 2.7000, 2.8000, 2.9000, 3.0000, 3.1000],
        [3.2000, 3.3000, 3.4000, 3.5000, 3.6000, 3.7000, 3.8000, 3.9000],
        [4.0000, 4.1000, 4.2000, 4.3000, 4.4000, 4.5000, 4.6000, 4.7000],
        [4.8000, 4.9000, 5.0000, 5.1000, 5.2000, 5.3000, 5.4000, 5.5000],
        [5.6000, 5.7000, 5.8000, 5.9000, 6.0000, 6.1000, 6.2000, 6.3000]],
       dtype=torch.float64)

and then we can take B as follows:

>>> temp = A.view(2, 3, 32, 2)
>>> ind  = np.arange(32)
>>> ind2 = ind.reshape(8, 4).T.reshape(16, 2)
>>> B = temp[:, :, ind2, :]
>>> B.shape
torch.Size([2, 3, 16, 2, 2])


>>> B[0, 0, 0]
tensor([[0.0000, 0.1000],
        [0.8000, 0.9000]], dtype=torch.float64)

So the blocks in B are created as intended. The next 2x2 block, moving down the first column of blocks, is

>>> B[0, 0, 1]
tensor([[1.6000, 1.7000],
        [2.4000, 2.5000]], dtype=torch.float64)

And for the first 2x2 block in the second column of blocks of A:

>>> B[0, 0, 4]
tensor([[0.2000, 0.3000],
        [1.0000, 1.1000]], dtype=torch.float64)
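
The original question also asked how to get back to the shape of A. One possible way to do both directions without index arrays is a reshape plus permute. This is only a sketch: `split_blocks` and `merge_blocks` are hypothetical helper names, and the permute order is chosen to reproduce the block ordering shown above (blocks run down each column of blocks first).

```python
import torch

def split_blocks(x, s):
    """Split (n, c, h, w) into (n, c, num_blocks, s, s) non-overlapping blocks."""
    n, c, h, w = x.shape
    # Axes after reshape: (n, c, block_row, row_in_block, block_col, col_in_block)
    x = x.reshape(n, c, h // s, s, w // s, s)
    # Put block_col before block_row so blocks are ordered column by column.
    x = x.permute(0, 1, 4, 2, 3, 5)
    return x.reshape(n, c, -1, s, s)

def merge_blocks(blocks, h, w):
    """Inverse of split_blocks: rebuild the (n, c, h, w) tensor."""
    n, c, _, s, _ = blocks.shape
    # Axes after reshape: (n, c, block_col, block_row, row_in_block, col_in_block)
    x = blocks.reshape(n, c, w // s, h // s, s, s)
    # Restore (n, c, block_row, row_in_block, block_col, col_in_block).
    x = x.permute(0, 1, 3, 4, 2, 5)
    return x.reshape(n, c, h, w)

A = torch.arange(2 * 3 * 64, dtype=torch.float64).reshape(2, 3, 8, 8) / 10.0
B = split_blocks(A, 2)          # shape (2, 3, 16, 2, 2)
A_back = merge_blocks(B, 8, 8)  # shape (2, 3, 8, 8), equal to A
```

Since this only rearranges axes, it should avoid the cost of advanced indexing, though I have not benchmarked it.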

Thank you! It is a little slow, but still faster than the brute-force way.

No problem!

Try replacing the numpy arrays with torch and see if that makes a difference. Note that if you are doing this in a loop, you can just create the tensor for indices once and reuse that in the loop.
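
As a rough sketch of that suggestion, the index can be built once with torch and reused. The names `block_index` and `to_blocks` are made up for illustration, and the hard-coded 32 and 2 assume the 8x8 images with 2x2 blocks from this thread.

```python
import torch

A = torch.arange(2 * 3 * 64, dtype=torch.float64).reshape(2, 3, 8, 8) / 10.0

# Build the (16, 2) index tensor once, outside any loop.
block_index = torch.arange(32).reshape(8, 4).t().reshape(16, 2)

def to_blocks(x, index):
    # View each 8x8 image as 32 row segments of length 2, then gather the
    # two segments that make up each 2x2 block.
    n, c = x.shape[:2]
    return x.reshape(n, c, 32, 2)[:, :, index, :]

B = to_blocks(A, block_index)  # shape (2, 3, 16, 2, 2)
```

If A lives on a GPU, moving `block_index` to the same device once is probably worthwhile as well.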

Thank you for this solution!

I would like to know how I can extend this 2D block solution to 3D blocks. I’ve been trying a couple of reshape sizes but can’t seem to get the same result that you did.
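
One way the 3D case might be approached is with a reshape plus permute rather than an index array. This is a sketch under assumptions: the input is taken to be (n, c, d, h, w), the blocks are cubes of side s, and the blocks are ordered depth first, then height, then width, which is not necessarily the ordering of the 2D solution above. `split_blocks_3d` and `merge_blocks_3d` are hypothetical names.

```python
import torch

def split_blocks_3d(x, s):
    """Split (n, c, d, h, w) into (n, c, num_blocks, s, s, s) cubes."""
    n, c, d, h, w = x.shape
    # Axes: (n, c, bd, sd, bh, sh, bw, sw) where b* index blocks, s* index inside a block
    x = x.reshape(n, c, d // s, s, h // s, s, w // s, s)
    # Gather the three block axes first, then the three in-block axes.
    x = x.permute(0, 1, 2, 4, 6, 3, 5, 7)
    return x.reshape(n, c, -1, s, s, s)

def merge_blocks_3d(blocks, d, h, w):
    """Inverse of split_blocks_3d."""
    n, c, _, s, _, _ = blocks.shape
    # Axes: (n, c, bd, bh, bw, sd, sh, sw)
    x = blocks.reshape(n, c, d // s, h // s, w // s, s, s, s)
    # Interleave block and in-block axes again: (n, c, bd, sd, bh, sh, bw, sw)
    x = x.permute(0, 1, 2, 5, 3, 6, 4, 7)
    return x.reshape(n, c, d, h, w)

X = torch.arange(2 * 3 * 4 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4, 4)
blocks = split_blocks_3d(X, 2)           # shape (2, 3, 8, 2, 2, 2)
X_back = merge_blocks_3d(blocks, 4, 4, 4)
```

Changing the order of the three block axes in the first permute (and mirroring that change in the merge) gives a different block ordering.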

Another potential solution, at least for image tensors:

import torch

def unstack(tensor: torch.Tensor, dimensions: int):
    """
    Unstacks the provided image tensor into square patches of the specified size.
    Expects a tensor in the shape of bs x c x n x n, where bs is the batch size
    and c is the number of channels. Unpacks the image to the size
    ((n / dimensions)^2 * bs, c, dimensions, dimensions).

    Parameters:
    - tensor (torch.Tensor): Image tensor to unstack
    - dimensions (int): Side length of each square patch.

    Returns
    - tensor (torch.Tensor): Unstacked image tensor
    """

    assert tensor.shape[-2] == tensor.shape[-1], "Image tensor must be square"
    assert tensor.shape[-1] % dimensions == 0, "Cannot unpack the image into that shape"

    channels = tensor.shape[1]

    # Cut non-overlapping windows along height (dim 2) and width (dim 3).
    # Result shape: (bs, c, n/dimensions, n/dimensions, dimensions, dimensions)
    tensor = tensor.unfold(2, dimensions, dimensions).unfold(3, dimensions, dimensions)

    # Move the two patch-grid axes ahead of the channel axis, then flatten
    # batch and grid into a single leading axis.
    tensor = tensor.permute(0, 3, 2, 1, 4, 5).reshape(-1, channels, dimensions, dimensions)

    return tensor
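
Since fold/unfold came up in the original question, here is a rough sketch of the round trip using `torch.nn.functional.unfold` and `torch.nn.functional.fold`. The shapes are hard-coded for the 2x3x8x8 example, and note one assumption that differs from the index-based answer earlier in the thread: the blocks here come out row by row (left to right, then top to bottom).

```python
import torch
import torch.nn.functional as F

A = torch.arange(2 * 3 * 64, dtype=torch.float32).reshape(2, 3, 8, 8)
s = 2  # block size

# unfold with stride == kernel_size extracts non-overlapping s x s blocks.
# Output shape: (2, 3*s*s, 16), with the channel index varying slowest in dim 1.
cols = F.unfold(A, kernel_size=s, stride=s)

# Rearrange to (batch, channel, block, s, s).
B = cols.reshape(2, 3, s * s, 16).permute(0, 1, 3, 2).reshape(2, 3, 16, s, s)

# Undo the rearrangement and let fold place each block back in the image.
cols_back = B.reshape(2, 3, 16, s * s).permute(0, 1, 3, 2).reshape(2, 3 * s * s, 16)
A_back = F.fold(cols_back, output_size=(8, 8), kernel_size=s, stride=s)
```

Because the blocks do not overlap, `fold` writes each value exactly once, so `A_back` matches `A`.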