Strategies for Reshaping Tensors to Meet Specific Requirements: Handling Spatial Continuity in Swin-Transformer Architecture

I am trying to implement the mask_windows in the swin-transformer architecture.

I have a mask tensor that looks like this:

tensor([[0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [6., 6., 6., 6., 7., 7., 8., 8.],
        [6., 6., 6., 6., 7., 7., 8., 8.]])

Its full shape is torch.Size([1, 8, 8, 1]) (the printout above shows the 8×8 spatial map with the batch and channel dimensions squeezed for readability).

I want to convert it to have the shape:

torch.Size([4, 4, 4, 1])

which results from partitioning the 8×8 spatial map into four non-overlapping 4×4 windows. The leading 4 in the desired shape is the number of windows, and the following 4, 4 are the window height and width.

My initial attempt was:

windows = x.view(-1, window_size, window_size, C)

However, this approach disrupts the spatial continuity of the windows: view only re-chunks the flattened memory, so each 4×4 "window" is built from consecutive rows rather than from a 4×4 spatial block.
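To make the failure concrete, here is a small sketch (building only the top row-regions of the mask above, which is enough to show the problem):

```python
import torch

window_size, C = 4, 1
x = torch.zeros(1, 8, 8, 1)
x[0, :4, 4:6, 0] = 1.  # region labelled 1 in the mask above
x[0, :4, 6:, 0] = 2.   # region labelled 2

windows = x.view(-1, window_size, window_size, C)
# The first "window" is rows 0-1 of the map reshaped into 4x4,
# not the top-left 4x4 spatial block:
print(windows[0, :, :, 0])
# tensor([[0., 0., 0., 0.],
#         [1., 1., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 2., 2.]])
```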

The correct way to do this is:

x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

My question is about the right mindset and step-by-step approach for tackling such problems. What is the proper mental framework or checklist to ensure that the output adheres to the specified requirements, aside from trial and error with small tensors?
I understand that this relates to how the view and permute operations work, but I'm struggling to distill my understanding into something systematic.

Could you guide me on how to think about such transformation?

To achieve the desired transformation, you should think about it in terms of reshaping and reorganizing the original tensor. Here’s a step-by-step approach:

  1. Understand the Desired Shape: First, understand the desired shape of the output tensor. Here it's [4, 4, 4, 1]: four windows, each of size window_size × window_size, with C channels.
  2. Identify the Dimensions to Divide: Determine which dimensions of the original tensor need to be split to get there. Here it's the height and width (2nd and 3rd dimensions), each split into H // window_size (or W // window_size) windows of size window_size.
  3. Use View to Divide Dimensions: Use view(B, H // window_size, window_size, W // window_size, window_size, C) to split each spatial axis into a "which window" axis and a "position inside the window" axis. This is legal because view only re-groups contiguous memory without moving any elements.
  4. Reorganize Dimensions: To make each window occupy a contiguous block in the output, move the two "which window" axes next to the batch axis with permute(0, 1, 3, 2, 4, 5), leaving the two "position inside the window" axes adjacent.
  5. Flatten for Final Shape: Finally, contiguous().view(-1, window_size, window_size, C) materializes the permuted layout in memory and merges the batch and window-grid axes into a single leading dimension.
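Putting the five steps together, here is a minimal end-to-end sketch on the mask from the question (the helper name window_partition is used here for illustration, wrapping exactly the two lines from your correct solution):

```python
import torch

def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    # Step 3: split each spatial axis into (num_windows, window_size)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Step 4: bring the two "which window" axes next to the batch axis
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # Step 5: merge batch and window-grid axes into one leading dimension
    return x.view(-1, window_size, window_size, C)

# Rebuild the mask from the question: nine labelled regions on an 8x8 map
mask = torch.zeros(1, 8, 8, 1)
spans = (slice(0, 4), slice(4, 6), slice(6, 8))
cnt = 0
for h in spans:
    for w in spans:
        mask[0, h, w, 0] = cnt
        cnt += 1

windows = window_partition(mask, window_size=4)
print(windows.shape)        # torch.Size([4, 4, 4, 1])
print(windows[0, :, :, 0])  # top-left window: all zeros, continuity preserved
```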

So, the key mental framework is to:

  • Understand the desired output shape.
  • Identify which dimensions of the original tensor need to be divided.
  • Use view to partition those dimensions.
  • Reorganize dimensions if necessary.
  • Flatten the tensor to achieve the desired final shape.
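As a systematic alternative to trial and error, you can validate the whole pipeline by asserting that each output window equals the corresponding spatial slice of the input, using a tensor of distinct values so any mix-up is caught:

```python
import torch

B, H, W, C = 1, 8, 8, 1
window_size = 4
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)

x6 = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x6.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

# Window (i, j) of the grid must equal the matching spatial slice
grid_w = W // window_size
for i in range(H // window_size):
    for j in range(grid_w):
        expected = x[0, i * window_size:(i + 1) * window_size,
                        j * window_size:(j + 1) * window_size]
        assert torch.equal(windows[i * grid_w + j], expected)
print("all windows match their spatial slices")
```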

Some further reading:

  1. PyTorch Tutorials
  2. Official PyTorch Documentation
  3. PyTorch Fundamentals by Deeplizard
  4. PyTorch Reshaping and Squeezing Tensors
  5. Understanding PyTorch Transformations: A Guide to Reshaping Tensors