Help understanding how this function works

Hi all - Can someone please explain what this function does and what a valid input for it would be? It comes from aggregation.py in Lightning-AI/lightning-bolts on GitHub (commit bcbbf6ab6c36430946dd8a416ddc7e697e8507fc):

import torch

def mean(res, key):
    # recursive mean for multilevel dicts
    return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean()

I’m trying to test it out but it’s clearly not the same as torch.mean() and doesn’t seem to accept tensors natively. I also couldn’t find information about it in the docs. I’m struggling to understand how it’s actually working due to its recursive nature, so an explanation of what is going on would also be helpful. Thanks!

I got some help on the Lightning channel. A valid input is a list of dicts whose values under the given key are tensors, for example:

t = torch.tensor([[100., 100., 200., 200.]])
x = {"test": t}
mean([x], "test")  # tensor(150.)
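
For anyone else who lands here, this is a rough sketch of how I understand the recursion, not an authoritative description of the library. For each element of res: if it is a dict, the function takes element[key]; otherwise it assumes the element is itself a list and calls mean on it first. The results are stacked with torch.stack and averaged. The "loss" key and the sample values below are made up for illustration.

```python
import torch

def mean(res, key):
    # same function as in the question, copied so the example is self-contained
    return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean()

# Flat list of dicts: the two "loss" tensors are stacked, then averaged.
flat = [{"loss": torch.tensor(1.0)}, {"loss": torch.tensor(3.0)}]
flat_mean = mean(flat, "loss")  # (1 + 3) / 2

# Nested input: the first element is a list, not a dict, so mean() recurses
# into it and reduces it to a single value before stacking.
nested = [
    [{"loss": torch.tensor(1.0)}, {"loss": torch.tensor(3.0)}],  # inner mean is 2
    {"loss": torch.tensor(4.0)},
]
nested_mean = mean(nested, "loss")  # mean of [2, 4]

print(flat_mean.item(), nested_mean.item())
```

Note that the nested case is a mean of means (3.0 here), not the mean over all three values (which would be about 2.67). A bare tensor is not a valid input because its elements are neither dicts nor lists of dicts, which would explain why it behaves differently from torch.mean().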