Ordering of elements when using torch.flatten() on 4D tensors

I have a 4D tensor of shape [32, 64, 64, 3], corresponding to [batch, timeframes, frequency_bins, features], and I call tensor.flatten(start_dim=2). I understand the shape then becomes [32, 64, 64*3], i.e. [batch, timeframes, frequency_bins*features]. But how are the elements ordered within the new flattened dimension of size 64*3? Do the first 64 indices correspond to what would have been [:, :, :, 0], the second 64 to [:, :, :, 1], and the final 64 to [:, :, :, 2]?

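No, the features are not grouped into blocks of 64. torch.flatten() keeps the row-major order of the original tensor, so the last dimension varies fastest: the first 3 entries are the features of frequency bin 0, the next 3 are the features of frequency bin 1, and so on. The ordering is easiest to see in this example: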

import torch

x = torch.arange(1*4*4*3).view(1, 4, 4, 3)
print(x)
> tensor([[[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11]],

           [[12, 13, 14],
            [15, 16, 17],
            [18, 19, 20],
            [21, 22, 23]],

           [[24, 25, 26],
            [27, 28, 29],
            [30, 31, 32],
            [33, 34, 35]],

           [[36, 37, 38],
            [39, 40, 41],
            [42, 43, 44],
            [45, 46, 47]]]])
y = torch.flatten(x, start_dim=2)
print(y)
> tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
           [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
           [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
           [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]]])

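As you can see, the last two dimensions (the “squares” in x) are flattened row by row into the rows of y: the element at position [f, c] of a square ends up at index f*3 + c of the corresponding row.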
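If it helps, here is a small sketch that checks this ordering on the same toy tensor and shows one possible way to get the feature-grouped layout the question describes. The variable names (num_feat, z) are just illustrative, and using permute before flatten is only one option, not the only way to do it.

```python
import torch

x = torch.arange(1 * 4 * 4 * 3).view(1, 4, 4, 3)  # [batch, time, freq, feat]
y = torch.flatten(x, start_dim=2)                  # [batch, time, freq * feat]

num_feat = x.shape[3]

# Element [f, c] of the last two dims lands at flat index f * num_feat + c.
f, c = 2, 1
assert y[0, 0, f * num_feat + c] == x[0, 0, f, c]

# So one feature is recovered from y by striding, not by taking a contiguous block.
assert torch.equal(y[:, :, 0::num_feat], x[:, :, :, 0])

# To group by feature instead, move the feature axis in front of the
# frequency axis before flattening.
z = x.permute(0, 1, 3, 2).flatten(start_dim=2)
assert torch.equal(z[:, :, :4], x[:, :, :, 0])  # first block is feature 0
print(z[0, 0])
```

With the [32, 64, 64, 3] tensor from the question, the same permute would give a last dimension whose first 64 entries are feature 0, the next 64 feature 1, and the final 64 feature 2.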