I'm afraid I can't answer that: besides not understanding your question (what is being reshaped to what?), I also don't know much about benchmarking.
However, I did a quick test (see below) after reading "List comprehension vs map", and indeed, in this case too, looping with a list comprehension is faster than using a lambda expression with map.
import torch
import timeit
import numpy as np
# Seed both RNGs so the tensor and the random split sizes are reproducible.
torch.manual_seed(0)
np.random.seed(0)

# M: length of dim 0 (the dimension that gets split).
# N: length of dim 1, also used as the exclusive upper bound of the values.
# F: length of dim 2.
# K: number of chunks dim 0 is split into.
M, N, F, K = 50, 100, 30, 10

# Integer-valued samples in [0, N), stored as floats so torch.mean accepts them.
t0 = torch.randint(0, N, (M, N, F), dtype=torch.float)

# Uneven chunk sizes: one multinomial draw distributes the M rows over K
# equally likely chunks, so the sizes always sum to M. `.tolist()` yields
# plain Python ints, which is what torch.split documents for its sizes.
split = np.random.multinomial(M, np.ones(K) / K, size=1)[0].tolist()

# Even chunk sizes: M // K rows each, with any remainder spread over the first
# chunks so the sizes still sum to M when K does not divide M.
_base, _extra = divmod(M, K)
split_stab = [_base + (i < _extra) for i in range(K)]
def list_comp():
    """Average each uneven chunk of ``t0`` over dim 0 using a list comprehension.

    ``t0`` is split along dim 0 into chunks whose sizes are given by the
    module-level ``split`` list.

    Returns:
        A list with one tensor per chunk, each the mean of that chunk over dim 0.
    """
    return [torch.mean(t, dim=0) for t in torch.split(t0, split)]
def map_lambda():
    """Average each uneven chunk of ``t0`` over dim 0 using ``map`` and a lambda.

    ``t0`` is split along dim 0 into chunks whose sizes are given by the
    module-level ``split`` list. The ``map`` + ``lambda`` form is deliberate:
    it is the variant being timed against the list comprehension.

    Returns:
        A list with one tensor per chunk, each the mean of that chunk over dim 0.
    """
    return list(map(lambda x: torch.mean(x, dim=0), torch.split(t0, split)))
def list_comp_stab():
    """Average each even chunk of ``t0`` over dim 0 using a list comprehension.

    ``t0`` is split along dim 0 into chunks whose sizes are given by the
    module-level ``split_stab`` list.

    Returns:
        A list with one tensor per chunk, each the mean of that chunk over dim 0.
    """
    return [torch.mean(t, dim=0) for t in torch.split(t0, split_stab)]
def map_lambda_stab():
    """Average each even chunk of ``t0`` over dim 0 using ``map`` and a lambda.

    ``t0`` is split along dim 0 into chunks whose sizes are given by the
    module-level ``split_stab`` list. The ``map`` + ``lambda`` form is
    deliberate: it is the variant being timed against the list comprehension.

    Returns:
        A list with one tensor per chunk, each the mean of that chunk over dim 0.
    """
    return list(map(lambda x: torch.mean(x, dim=0), torch.split(t0, split_stab)))
# Run every variant the same number of times and print the total elapsed
# seconds for each, in order: uneven split (comprehension, map) followed by
# even split (comprehension, map).
number = 10**4
for stmt in (
    "list_comp()",
    "map_lambda()",
    "list_comp_stab()",
    "map_lambda_stab()",
):
    print(timeit.timeit(stmt, globals=locals(), number=number))