Is it possible to stack multiple transformations/functions in PyTorch into a single function? I’m ideally looking for something like this (possibly with more care taken over tensor shapes):
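Roughly like this (hypothetical sketch: `stack_transforms` doesn't exist and `f1`/`f2` are placeholder functions; it's just to show the shape of the API I'm after):

```python
import torch

def f1(x):
    return x * 2

def f2(x):
    return torch.sin(x)

# What I'd like (hypothetical API, does not exist in PyTorch):
# g = torch.stack_transforms([f1, f2])
# y = g(x)  # would be equivalent to torch.stack([f1(x), f2(x)])
```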
No, you have to apply the functions and then stack the results.
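E.g. (with placeholder functions):

```python
import torch

def f1(x):
    return x * 2

def f2(x):
    return torch.sin(x)

x = torch.randn(4)
# apply each function, then stack the results along a new leading dim
y = torch.stack([f(x) for f in (f1, f2)])  # shape (2, 4)
```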
I recall that in the olden days the TorchScript fuser would fuse this into a single kernel; torch.compile may still do that today.
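Untested sketch, but something along these lines should let the compiler try to fuse the element-wise ops and the stack:

```python
import torch

def apply_and_stack(x):
    # two element-wise transforms stacked into one output tensor
    return torch.stack([x * 2, torch.sin(x)])

compiled_fn = torch.compile(apply_and_stack)  # Inductor may fuse these into one kernel
y = compiled_fn(torch.randn(4))
```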
Hi @tom, thanks for your reply. Makes sense. I'll have a small number of functions to apply to different elements from a large batch of data, so I guess the most efficient approach will be to use some sort of fancy indexing to extract the sub-tensor each function needs, then pass those sub-tensors through their functions separately (vectorising per function rather than passing every single element through its own function call, which would presumably be slower). Something like the sketch below.
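A minimal sketch of what I mean, assuming each element carries an integer label selecting which function applies (the labels and functions here are made up):

```python
import torch

def f1(x):
    return x * 2

def f2(x):
    return torch.sin(x)

fs = [f1, f2]

x = torch.randn(1000)
# assumed: one label per element, selecting which function to apply
labels = torch.randint(len(fs), (1000,))

out = torch.empty_like(x)
for i, f in enumerate(fs):
    mask = labels == i       # boolean fancy index for this function's elements
    out[mask] = f(x[mask])   # one vectorised call per function
```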