Is it possible to stack multiple transformations/functions in PyTorch into a single function?

Is it possible to stack multiple transformations/functions in PyTorch into a single function? I’m ideally looking for something like this (possibly with more care taken over tensor shapes):

import torch

f_stack = torch.stack([lambda x: x+1, lambda x: x*2, lambda x: x*x])

f_stack(torch.tensor(3))
# >>> tensor([4, 6, 9])

f_stack(torch.tensor(4))
# >>> tensor([5, 8, 16])

(see associated Stack Overflow question)

No, you have to apply the functions individually and then stack the results.
I recall that in the olden days this would be fused into a single kernel by the TorchScript fuser; it may still be fused by torch.compile or so.
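A minimal sketch of the apply-then-stack approach, using the lambdas from your example:

```python
import torch

fns = [lambda x: x + 1, lambda x: x * 2, lambda x: x * x]
x = torch.tensor(3)

# Apply each function to the input, then stack the results into one tensor.
out = torch.stack([f(x) for f in fns])
print(out)  # tensor([4, 6, 9])
```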

Best regards

Thomas

Hi @tom, thanks for your reply. Makes sense. I'll have a small number of functions to apply to different elements of a large batch of data, so I guess the most efficient approach is to use some sort of fancy indexing to extract the sub-tensor each function needs, and then pass each sub-tensor through its function separately. That way each function is vectorised over its sub-tensor, rather than passing every single element through its own function one at a time, which would presumably be slower.
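One way to sketch that fancy-indexing idea (apply_by_group is a hypothetical helper, not a PyTorch API; the group tensor says which function handles each batch element):

```python
import torch

def apply_by_group(fns, group, x):
    # Hypothetical helper: applies fns[i] to the elements of x where group == i.
    out = torch.empty_like(x)
    for i, f in enumerate(fns):
        mask = group == i          # boolean mask selecting this function's elements
        out[mask] = f(x[mask])     # vectorised call on the extracted sub-tensor
    return out

fns = [lambda x: x + 1, lambda x: x * 2, lambda x: x * x]
x = torch.tensor([3, 3, 3, 0])
group = torch.tensor([0, 1, 2, 0])  # element -> function index
print(apply_by_group(fns, group, x))  # tensor([4, 6, 9, 1])
```

This does one vectorised call per function (cheap when the number of functions is small), not one call per element.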

Just to clarify: the example I gave in the original question was a bit unclear. This one is hopefully clearer:

import torch

f_stack = torch.stack([lambda x: x+1, lambda x: x*2, lambda x: x*x])

f_stack(torch.tensor([0, 0, 0]))
# >>> tensor([1, 0, 0])

f_stack(torch.tensor([3, 3, 3]))
# >>> tensor([4, 6, 9])

f_stack(torch.tensor([3, 0, 0]))
# >>> tensor([4, 0, 0])

I assume the same answer still stands.
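For what it's worth, the clarified semantics (function i applied to element i) can be sketched in plain Python; make_f_stack is a hypothetical name, and each function is still called on a single element, so this is not vectorised:

```python
import torch

def make_f_stack(fns):
    # Hypothetical factory: returns a callable that applies fns[i]
    # to element i of the input and stacks the results.
    def apply(x):
        return torch.stack([f(xi) for f, xi in zip(fns, x)])
    return apply

fns = [lambda x: x + 1, lambda x: x * 2, lambda x: x * x]
f_stack = make_f_stack(fns)
print(f_stack(torch.tensor([3, 0, 0])))  # tensor([4, 0, 0])
```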