How to split a tensor based on a list, then get the average along dim=0?

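You can get a tuple of chunks with torch.split by passing your list of chunk sizes as the second argument (the sizes must add up to the tensor's length along the split dimension). Then define a function
f = lambda x: torch.mean(x, dim=0) and apply it to your tuple T with
list(map(f, T)) to get the per-chunk averages you want.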
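Here is a minimal end-to-end sketch of that approach. The input tensor x and the size list sizes are made-up example data, not from the question. The final torch.stack call is an optional extra step that turns the list of per-chunk means into a single tensor.

```python
import torch

# Hypothetical example data: 5 rows, 3 columns
x = torch.arange(15, dtype=torch.float32).reshape(5, 3)
sizes = [2, 3]  # hypothetical chunk sizes; must sum to x.size(0)

# With a list of sizes, torch.split returns a tuple of tensors of those lengths along dim 0
T = torch.split(x, sizes, dim=0)

# Average each chunk along dim=0
f = lambda t: torch.mean(t, dim=0)
means = list(map(f, T))

# Optional: stack the per-chunk means into one (num_chunks, 3) tensor
result = torch.stack(means)
print(result)
```

A list comprehension such as [t.mean(dim=0) for t in T] does the same thing as the map call, if you find it easier to read.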
