You can get the tuple by using `torch.split`, and then define a function `f = lambda x: torch.mean(x, dim=0)`, which you can apply to your tuple `T` as `list(map(f, T))` to get the output you want.
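Here is a minimal end-to-end sketch of that approach: split with `torch.split`, then take `torch.mean(..., dim=0)` of each piece. The input tensor `x` (6 rows, 2 features) and the chunk size of 2 are hypothetical stand-ins, since the original tensor isn't shown here. The final `torch.stack` is an optional extra step in case you want one tensor instead of a list.

```python
import torch

# Hypothetical input: 6 rows of 2 features (assumed shape, for illustration only).
x = torch.arange(12, dtype=torch.float32).reshape(6, 2)

# torch.split returns a tuple of tensors along dim 0 (the default),
# here 3 chunks of 2 rows each.
T = torch.split(x, 2)

# Mean of each chunk over its rows (dim=0); each result has shape (2,).
f = lambda t: torch.mean(t, dim=0)
means = list(map(f, T))

# Optional: combine the per-chunk means into a single (3, 2) tensor.
result = torch.stack(means)
# result values: [[1, 2], [5, 6], [9, 10]]
print(result)
```

A list comprehension such as `[t.mean(dim=0) for t in T]` does the same thing. If the split size doesn't divide the number of rows evenly, `torch.split` returns a smaller last chunk. Each mean still has shape `(features,)`, so `torch.stack` works in that case too.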