How to split a tensor based on a list, then get the average along dim=0?


Suppose there is a tensor of shape [M, N, F] and a list of segment sizes (the sizes sum to M, and the list has length K). After splitting the tensor into a K-length tuple T, is there a way to get the average of each tensor in T along dim=0 without using a loop? Since each tensor has shape [K_i, N, F] and the size of dim=0 differs between them, I can't stack them or apply torch.mean() directly.


You can get the tuple with torch.split, then define a function
f = lambda x: torch.mean(x, dim=0) and apply it to your tuple T with
list(map(f, T)) to get the output you want.
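The map above still runs one torch.mean per segment in Python. For the "without a loop" part of the question, one option is a scatter-add: give every row along dim 0 a segment id, sum rows with the same id via index_add_, then divide by the segment sizes. This is a swapped-in technique, not the approach from the answer above; the helper name segment_mean is hypothetical. It is a minimal sketch assuming sizes is a list of non-negative ints summing to M. A zero-length segment gives 0/0 = nan, which is also what torch.mean returns for an empty split.

```python
import torch


def segment_mean(x, sizes):
    """Mean over dim 0 of each consecutive segment of x, with no Python loop.

    x: tensor of shape [M, ...]; sizes: K ints summing to M.
    Returns shape [K, ...]; row i matches torch.split(x, sizes)[i].mean(dim=0).
    """
    sizes_t = torch.as_tensor(sizes, device=x.device)
    k = sizes_t.numel()
    # Segment id for every row along dim 0, e.g. sizes [1, 2, 3] -> [0, 1, 1, 2, 2, 2]
    seg_ids = torch.repeat_interleave(torch.arange(k, device=x.device), sizes_t)
    # Sum together all rows that share a segment id
    sums = torch.zeros((k,) + tuple(x.shape[1:]), dtype=x.dtype, device=x.device)
    sums.index_add_(0, seg_ids, x)
    # Divide by segment length; counts get shape [K, 1, ..., 1] so they broadcast
    counts = sizes_t.to(x.dtype).view(-1, *([1] * (x.dim() - 1)))
    return sums / counts


x = torch.arange(6.0).view(6, 1, 1)  # rows hold the values 0..5
print(segment_mean(x, [1, 2, 3]).flatten().tolist())  # [0.0, 1.5, 4.0]
```

The result is a single [K, N, F] tensor rather than a tuple, so it can be fed straight into further batched operations.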


Thanks. One more question about the lambda and map functions: will they slow down the tensor operation compared to the case where every chunk has the same size along dim=0, so that the tensor can simply be reshaped and averaged?

I'm afraid I can't answer that: I'm not sure I understand the question (what is being reshaped into what?), and I don't know much about benchmarking.

However, after reading List comprehension vs map, I ran a quick test (see below), and in this case too, a loop written as a list comprehension was faster than map with a lambda expression.

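import torch
import timeit
import numpy as np


# Tensor of shape [M, N, F], split into K segments along dim 0
M, N, F, K = 50, 100, 30, 10

# Random integer values in [0, 100), stored as float so torch.mean works
t0 = torch.randint(0, 100, (M, N, F), dtype=torch.float)
# Random segment sizes that sum to M (a size can be 0); .tolist() gives Python ints
split = np.random.multinomial(M, np.ones(K) / K).tolist()
# Equal segment sizes for the "_stab" variants: M // K rows per segment
split_stab = [M // K] * K

def list_comp():
    # Variable-size segments, list comprehension
    return [torch.mean(t, dim=0) for t in torch.split(t0, split)]

def map_lambda():
    # Variable-size segments, map + lambda
    return list(map(lambda x: torch.mean(x, dim=0), torch.split(t0, split)))

def list_comp_stab():
    # Equal-size segments, list comprehension
    return [torch.mean(t, dim=0) for t in torch.split(t0, split_stab)]

def map_lambda_stab():
    # Equal-size segments, map + lambda
    return list(map(lambda x: torch.mean(x, dim=0), torch.split(t0, split_stab)))

number = 10**4

# Total time in seconds for `number` calls of each variant
print(timeit.timeit("list_comp()", globals=globals(), number=number))
print(timeit.timeit("map_lambda()", globals=globals(), number=number))
print(timeit.timeit("list_comp_stab()", globals=globals(), number=number))
print(timeit.timeit("map_lambda_stab()", globals=globals(), number=number))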
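For the equal-size case raised in the follow-up question, the split is not needed at all: if M is divisible by K, a reshape to [K, M // K, N, F] followed by a single mean over dim 1 gives the same K averages in one call. A minimal sketch, assuming equal chunks and reusing the shape names from the thread (the data here is random, not the benchmark's t0):

```python
import torch

M, N, F, K = 50, 30, 30, 10
x = torch.randn(M, N, F)

# Only valid when every chunk has the same length, i.e. M is divisible by K
assert M % K == 0
chunk = M // K

# [M, N, F] -> [K, chunk, N, F]: rows 0..chunk-1 form chunk 0, and so on,
# matching the order of torch.split(x, chunk); then one mean over dim 1
means = x.reshape(K, chunk, N, F).mean(dim=1)

# Same result as splitting and averaging each chunk separately
ref = torch.stack([t.mean(dim=0) for t in torch.split(x, chunk)])
print(means.shape)  # torch.Size([10, 30, 30])
print(torch.allclose(means, ref, atol=1e-6))
```

Because this replaces K separate torch.mean calls with one, it is a natural extra variant to time alongside the four functions above.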