Reduce the number of operations in this PyTorch code

I have an input variable ‘fm’ that needs to be transformed into ‘fm1’ and ‘fm3’ to be processed by the system. However, in doing so I create a temporary variable ‘fm2’, which I do not need at all.

Since I want to reduce the time and memory used by my deep model and make it as fast as possible, I suspect it should be possible to obtain ‘fm3’ directly from ‘fm1’ with fewer ‘view()’ and ‘permute()’ calls, so that ‘fm2’ is not needed at all. Can someone please help me with this?

Alternatively, obtaining ‘fm3’ directly from ‘fm’ would also work.

Here is my code:

import torch

h=16 # any power of 2
w=32 # any power of 2
fm = torch.arange(1*1*h*w).view(1,1,h,w)
print(fm)

# Step 1: space-to-depth with block size R: (b, c, h, w) -> (b, c*R*R, h/R, w/R)
b,c,h,w = fm.size()
R=2
out_channel = c*(R**2)
out_h = h//R
out_w = w//R
fm1 = fm.view(b, c, out_h, R, out_w, R).permute(0,1,3,5,2,4).contiguous().view(b,out_channel, out_h, out_w)
print(fm1)

# Step 2: space-to-depth again, this time with block size r, applied to fm1
b,c,h,w = fm1.size()
r=4
out_channel = c*(r**2)
out_h = h//r
out_w = w//r
fm2 = fm1.view(b, c, out_h, r, out_w, r).permute(0,1,3,5,2,4).contiguous().view(b,out_channel, out_h, out_w)
# This fm2 is temporary and is not required

# Step 3: channel shuffle of fm2 with G = R**2 groups
b,c,h,w = fm2.size()
G=R**2
fm3 = fm2.view(b, G, c // G, h, w).permute(0, 2, 1, 3, 4).contiguous().view(b, c, h, w)
print(fm3)
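
One possible way to get ‘fm3’ from ‘fm1’ without materialising ‘fm2’ is sketched below. It rests on an assumption: ‘fm1’ has exactly G = R**2 channels (that is, ‘fm’ has c = 1, as in the code above). In that case the second space-to-depth and the channel shuffle appear to collapse into a single permutation, so only one ‘contiguous()’ copy is made instead of two. The helper names reference_fm3 and direct_fm3_from_fm1 are made up for illustration, and the last lines compare the shortcut with the original two-step route.

```python
import torch

def reference_fm3(fm1, R, r):
    # Original route: space-to-depth by r (creates fm2), then channel shuffle with G = R**2.
    b, c, h, w = fm1.size()
    fm2 = (fm1.view(b, c, h // r, r, w // r, r)
              .permute(0, 1, 3, 5, 2, 4).contiguous()
              .view(b, c * r * r, h // r, w // r))
    b, c, h, w = fm2.size()
    G = R ** 2
    return (fm2.view(b, G, c // G, h, w)
               .permute(0, 2, 1, 3, 4).contiguous()
               .view(b, c, h, w))

def direct_fm3_from_fm1(fm1, r):
    # Hypothetical shortcut. ASSUMPTION: fm1 has exactly R**2 channels (fm had c == 1).
    # View dims are (b, channel, y, r_h, x, r_w); the channel axis is moved
    # behind the two r offsets, which is what the shuffle did in the original route.
    b, c, h, w = fm1.size()
    return (fm1.view(b, c, h // r, r, w // r, r)
               .permute(0, 3, 5, 1, 2, 4).contiguous()
               .view(b, r * r * c, h // r, w // r))

R, r = 2, 4
fm = torch.arange(16 * 32).view(1, 1, 16, 32)
b, c, h, w = fm.size()
fm1 = (fm.view(b, c, h // R, R, w // R, R)
         .permute(0, 1, 3, 5, 2, 4).contiguous()
         .view(b, c * R * R, h // R, w // R))

fm3_ref = reference_fm3(fm1, R, r)
fm3_direct = direct_fm3_from_fm1(fm1, r)
print(torch.equal(fm3_ref, fm3_direct))
```

If ‘fm’ has more than one channel, the grouping in the shuffle no longer lines up with the R*R block offsets, so this shortcut would have to be re-derived for that case.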

The printed output of ‘fm’, ‘fm1’ and ‘fm3’ is as follows (for brevity, here h = w = 8, R = 2 and r = 2):

fm: tensor([[[[ 0,  1,  2,  3,  4,  5,  6,  7],
          [ 8,  9, 10, 11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20, 21, 22, 23],
          [24, 25, 26, 27, 28, 29, 30, 31],
          [32, 33, 34, 35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44, 45, 46, 47],
          [48, 49, 50, 51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60, 61, 62, 63]]]])

 fm1: tensor([[[[ 0,  2,  4,  6],
          [16, 18, 20, 22],
          [32, 34, 36, 38],
          [48, 50, 52, 54]],

         [[ 1,  3,  5,  7],
          [17, 19, 21, 23],
          [33, 35, 37, 39],
          [49, 51, 53, 55]],

         [[ 8, 10, 12, 14],
          [24, 26, 28, 30],
          [40, 42, 44, 46],
          [56, 58, 60, 62]],

         [[ 9, 11, 13, 15],
          [25, 27, 29, 31],
          [41, 43, 45, 47],
          [57, 59, 61, 63]]]])

 fm3: tensor([[[[ 0,  4],
          [32, 36]],

         [[ 1,  5],
          [33, 37]],

         [[ 8, 12],
          [40, 44]],

         [[ 9, 13],
          [41, 45]],

         [[ 2,  6],
          [34, 38]],

         [[ 3,  7],
          [35, 39]],

         [[10, 14],
          [42, 46]],

         [[11, 15],
          [43, 47]],

         [[16, 20],
          [48, 52]],

         [[17, 21],
          [49, 53]],

         [[24, 28],
          [56, 60]],

         [[25, 29],
          [57, 61]],

         [[18, 22],
          [50, 54]],

         [[19, 23],
          [51, 55]],

         [[26, 30],
          [58, 62]],

         [[27, 31],
          [59, 63]]]])
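
Reading the printed ‘fm3’ above, channel k*R*R + g (with k = p*r + q and g = i*R + j) seems to hold the elements fm[(y*r + p)*R + i, (x*r + q)*R + j]. Under that reading, ‘fm3’ could be taken straight from ‘fm’ with one ‘view()’, one ‘permute()’ and one final ‘view()’. The sketch below is a minimal version that assumes a single input channel (c = 1) and h, w divisible by R*r; the function name fm3_from_fm is hypothetical.

```python
import torch

def fm3_from_fm(fm, R, r):
    # Hypothetical helper. ASSUMPTION: fm has shape (b, 1, h, w) and is contiguous,
    # with h and w divisible by R * r.
    b, c, h, w = fm.size()
    assert c == 1, "this sketch only covers the single-channel case"
    out_h, out_w = h // (R * r), w // (R * r)
    # Split each spatial axis into (coarse index, r offset, R offset):
    # dims are (b, y, r_h, R_h, x, r_w, R_w).
    x = fm.view(b, out_h, r, R, out_w, r, R)
    # Reorder to (b, r_h, r_w, R_h, R_w, y, x) and merge the four offsets into channels.
    x = x.permute(0, 2, 5, 3, 6, 1, 4).contiguous()
    return x.view(b, r * r * R * R, out_h, out_w)

fm = torch.arange(64).view(1, 1, 8, 8)
fm3_direct = fm3_from_fm(fm, R=2, r=2)
print(fm3_direct.shape)            # torch.Size([1, 16, 2, 2])
print(fm3_direct[0, 1].tolist())   # [[1, 5], [33, 37]], the second channel of fm3 above
```

Compared with the three-step route, this makes a single ‘contiguous()’ copy and never creates ‘fm1’ or ‘fm2’; if ‘fm1’ is still needed by the system, it has to be computed separately.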