PyTorch: slice and stack a matrix along dimension 0

So I want to slice a matrix of size (n², n²) into n² matrices of size (n, n), stacked along dimension 0, resulting in an (n², n, n) tensor. e.g.:

import torch

a = torch.arange(1, 82, dtype=torch.float32).view(9, 9)  # this is the matrix to work on
b = a.view(3, 3, 3, 3)  # note that here n = 3
print(b.permute(0, 2, 1, 3))

The result is:

tensor([[[[ 1.,  2.,  3.],
          [10., 11., 12.],
          [19., 20., 21.]],

         [[ 4.,  5.,  6.],
          [13., 14., 15.],
          [22., 23., 24.]],

         [[ 7.,  8.,  9.],
          [16., 17., 18.],
          [25., 26., 27.]]],


        [[[28., 29., 30.],
          [37., 38., 39.],
          [46., 47., 48.]],

         [[31., 32., 33.],
          [40., 41., 42.],
          [49., 50., 51.]],

         [[34., 35., 36.],
          [43., 44., 45.],
          [52., 53., 54.]]],


        [[[55., 56., 57.],
          [64., 65., 66.],
          [73., 74., 75.]],

         [[58., 59., 60.],
          [67., 68., 69.],
          [76., 77., 78.]],

         [[61., 62., 63.],
          [70., 71., 72.],
          [79., 80., 81.]]]])
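The permuted tensor maps back to `a` as `b[i, j, k, l] == a[n*i + k, n*j + l]`: `i` and `j` pick the block row and block column, `k` and `l` the position inside the block. A small sketch that checks this mapping by brute force, assuming n = 3 as above (the float dtype is only there to match the printed output; the loop variable names are illustrative):

```python
import torch

n = 3
a = torch.arange(1, n**4 + 1, dtype=torch.float32).view(n * n, n * n)
# Split rows into (block row i, row-in-block k) and columns into
# (block column j, column-in-block l), then reorder the axes to (i, j, k, l).
b = a.view(n, n, n, n).permute(0, 2, 1, 3)

# Check every element against the explicit index formula.
for i in range(n):
    for j in range(n):
        for k in range(n):
            for l in range(n):
                assert b[i, j, k, l] == a[n * i + k, n * j + l]

# Equivalently, b[i, j] is the (n, n) block you would get by plain slicing.
assert torch.equal(b[1, 2], a[n * 1:n * 2, n * 2:n * 3])
print(b[1, 2])
```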

Almost there, except that it’s a (3, 3, 3, 3) tensor. What I want instead is a (9, 3, 3) tensor:

tensor([[[ 1.,  2.,  3.],
         [10., 11., 12.],
         [19., 20., 21.]],

        [[ 4.,  5.,  6.],
         [13., 14., 15.],
         [22., 23., 24.]],

        [[ 7.,  8.,  9.],
         [16., 17., 18.],
         [25., 26., 27.]],

        [[28., 29., 30.],
         [37., 38., 39.],
         [46., 47., 48.]],

        [[31., 32., 33.],
         [40., 41., 42.],
         [49., 50., 51.]],

        [[34., 35., 36.],
         [43., 44., 45.],
         [52., 53., 54.]],

        [[55., 56., 57.],
         [64., 65., 66.],
         [73., 74., 75.]],

        [[58., 59., 60.],
         [67., 68., 69.],
         [76., 77., 78.]],

        [[61., 62., 63.],
         [70., 71., 72.],
         [79., 80., 81.]]])

I can’t figure out how to do this. Calling view(9, 3, 3) doesn’t work: it messes up the ordering of elements in my (3, 3) submatrices. On the other hand, if there is any other way to operate on those (3, 3) slices of matrix a (inverse, matrix multiplication, etc.), I’d really like to hear about it.
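On operating on the blocks: many of PyTorch's linear-algebra functions treat every leading dimension as a batch dimension, so the (3, 3, 3, 3) permuted tensor can often be used directly. A minimal sketch, assuming `torch.linalg` is available (PyTorch 1.8 or later). It uses a random matrix with a shifted diagonal in every block because the `arange`-based blocks above are singular (their rows form an arithmetic progression, so rank 2) and have no inverse; the shift of 10 is an arbitrary illustrative choice:

```python
import torch

n = 3
torch.manual_seed(0)
# Random (n*n, n*n) matrix; adding 10 * I inside every (n, n) block keeps
# each block well away from singular (illustrative choice, not required).
a = torch.randn(n * n, n * n) + 10 * torch.eye(n).repeat(n, n)

# blocks[i, j] is the (i, j)-th (n, n) sub-matrix of a; no copy is made.
blocks = a.view(n, n, n, n).permute(0, 2, 1, 3)

# torch.linalg.inv inverts over the last two dims and batches over the rest,
# so all n*n blocks are inverted in one call.
inv_blocks = torch.linalg.inv(blocks)

# The @ operator also batches (and broadcasts) over leading dims.
prod = blocks @ inv_blocks  # block-wise products, each close to identity
w = torch.randn(n, n)
right = blocks @ w  # every block multiplied by the same (n, n) matrix w

print(torch.allclose(prod, torch.eye(n).expand_as(prod), atol=1e-4))
```

Reshaping to (9, 3, 3) first works just as well for these calls; keeping the 4-D layout simply keeps the block coordinates (i, j) explicit.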

Problem solved by calling .reshape(9, 3, 3) on the permuted tensor. I guess creating a new copy is necessary because the permutation breaks the original memory ordering, so the result can no longer be a plain view of a.
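For reference, a sketch of why `view` fails here and what `reshape` does instead. After `permute`, the tensor is no longer contiguous and its first two dimensions cannot be merged by adjusting strides alone, so `view(9, 3, 3)` raises a `RuntimeError`, while `reshape` falls back to a copy (equivalent to `.contiguous().view(...)`). The helpers `to_blocks` and `from_blocks` are hypothetical names used only for illustration:

```python
import torch


def to_blocks(a: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: split an (n*n, n*n) matrix into a stack of
    n*n blocks of shape (n, n), ordered row-major by block position."""
    return a.reshape(n, n, n, n).permute(0, 2, 1, 3).reshape(n * n, n, n)


def from_blocks(blocks: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical inverse of to_blocks: reassemble the (n*n, n*n) matrix."""
    return blocks.reshape(n, n, n, n).permute(0, 2, 1, 3).reshape(n * n, n * n)


n = 3
a = torch.arange(1, n**4 + 1, dtype=torch.float32).view(n * n, n * n)
permuted = a.view(n, n, n, n).permute(0, 2, 1, 3)
print(permuted.is_contiguous())  # False: permute only rearranges strides

try:
    permuted.view(n * n, n, n)  # requires compatible strides, so it fails
except RuntimeError as err:
    print("view failed:", err)

blocks = to_blocks(a, n)  # reshape copies because no view is possible
assert torch.equal(blocks, permuted.contiguous().view(n * n, n, n))
assert torch.equal(from_blocks(blocks, n), a)  # round trip recovers a
print(blocks[0])
```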