Thanks for your reply @tom. I have run more tests recently. For the first reshape step, we found that the choice of dimensions to be contracted matters a lot.
For example, when selecting 6 dimensions from a 31-dim tensor to permute and reshape, the dimensions (22, 24, 26, 18, 21, 0) cost ~48 ms, while (29, 0, 8, 25, 26, 27) cost ~25 ms (my test code is pasted at the end).
Profiling shows that the slower case is memory-bound and includes more uncoalesced global memory accesses, but so far I cannot work out how to calculate the memory-access cost from the chosen dimensions.
```python
import time

import torch

a = torch.randn([2] * 31, device="cuda", dtype=torch.complex64)
for i in range(20):
    torch.cuda.synchronize()
    t0 = time.time()
    # slow case (~48 ms):
    # a.permute((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 23, 25, 27, 28, 29, 30, 22, 24, 26, 18, 21, 0)).reshape([2**25, 2**6])
    # fast case (~25 ms):
    a.permute((1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 28, 30, 29, 8, 25, 26, 27, 0)).reshape([2**25, 2**6])
    torch.cuda.synchronize()
    t1 = time.time()
    print(f'Iter#{i} time/sec: ', t1 - t0)
```
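I also tried a back-of-the-envelope heuristic for the coalescing difference. It rests on an assumption (not verified against the actual PyTorch copy kernel) that the copy assigns 32 consecutive output elements to one warp, iterating the output in row-major order over the permuted layout. Under that assumption, counting the distinct 128-byte cache lines the first warp must read from the source tensor reproduces roughly the 2x gap between the two permutations:

```python
# Heuristic sketch, assuming one warp = 32 consecutive row-major output
# elements: count distinct 128-byte source cache lines the first warp reads.
NDIM = 31
ITEMSIZE = 8   # complex64 = 8 bytes per element
LINE = 128     # assumed cache-line size in bytes

def src_stride(dim):
    # element stride of `dim` in the original C-contiguous [2]*31 tensor
    return 1 << (NDIM - 1 - dim)

def warp_cache_lines(perm, warp=32):
    lines = set()
    for i in range(warp):
        # decompose output index i over the trailing permuted dims (size 2 each)
        off = 0
        for k, d in enumerate(reversed(perm)):
            off += ((i >> k) & 1) * src_stride(d)
        lines.add(off * ITEMSIZE // LINE)
    return len(lines)

slow = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20,
        23, 25, 27, 28, 29, 30, 22, 24, 26, 18, 21, 0)
fast = (1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 28, 30, 29, 8, 25, 26, 27, 0)

print(warp_cache_lines(slow), warp_cache_lines(fast))  # prints: 32 16
```

The slow permutation's innermost output dims map to large source strides, so every element in the warp lands on its own cache line (32 lines), while the fast one packs the dims with small strides (25, 26, 27) innermost and halves the line count (16). The real kernel's thread mapping may well differ, but the proxy at least matches the direction and magnitude of the measured gap.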