Function type in TorchScript

I’d like to parametrize my torch.jit.script'ed function with a function argument, e.g. so that a Forward-Backward algorithm can use either lambda x: torch.max(x, dim = dim).values or lambda x: torch.logsumexp(x, dim = dim). When I pass the argument as a lambda, TorchScript complains that I’m calling a Tensor-typed value, which happens because it types the argument as Tensor (despite the fact that it’s initialized to a lambda by default).
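For reference, a minimal sketch of what I’m trying to write (names are illustrative; this fails to compile because reduce_fn is typed as Tensor):

import torch

@torch.jit.script
def forward_backward(x, reduce_fn = lambda t: torch.logsumexp(t, dim = -1)):
    # TorchScript rejects reduce_fn(x): it believes reduce_fn is a Tensor
    return reduce_fn(x)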

Is there any way to tell TorchScript to type a value as a Callable?

Is it possible to template a TorchScript function with an external callable via some other mechanism?

Should I create an issue about this on GitHub?

TorchScript currently doesn’t support callables as values (we’re working on supporting them in the coming months). The compiler doesn’t look at default values at all, since it’s not always possible to recover a full type from a default, so if types aren’t specified it just assumes everything is a Tensor. To specify otherwise, you can use Python 3 type hints or mypy-style type comments (details).
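For instance, a minimal sketch of annotating non-Tensor arguments (the function itself is made up for illustration):

import torch
from typing import List, Optional

@torch.jit.script
def scale(x: torch.Tensor, factors: List[float], mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Without the annotations, TorchScript would assume every argument is a Tensor
    if mask is not None:
        x = x * mask
    for f in factors:
        x = x * f
    return x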

Function attributes on modules are the closest you can get at the moment, so something like

def my_fn(x, some_callable):
    return some_callable(x + 10)

would have to be changed to something like

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self, some_callable):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.some_callable = some_callable

    def forward(self, x):
        return self.some_callable(x + 10)

torch.jit.script(M(the_function))
torch.jit.script(M(some_other_function))
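A hedged usage sketch, assuming the stored callable is itself a scriptable function (the_function here is just an example):

import torch

def the_function(x):
    return torch.relu(x)

scripted = torch.jit.script(M(the_function))
print(scripted(torch.randn(3)))  # the function attribute is compiled together with the module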

This approach still has some significant limitations.

You can’t store several callables in a ModuleList or ModuleDict and index into it, because TorchScript sees them as modules and cannot subscript those containers. A sketch of one workaround follows.
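One way around that, assuming the set of candidate callables is known up front, is to select the behaviour with a compile-time constant flag instead of indexing a container (Reduce and use_max are made-up names):

import torch
import torch.nn as nn

class Reduce(nn.Module):
    use_max: torch.jit.Final[bool]

    def __init__(self, use_max: bool):
        super().__init__()
        self.use_max = use_max

    def forward(self, x):
        # Branch on a constant attribute rather than calling a stored callable
        if self.use_max:
            return torch.max(x, dim = -1).values
        return torch.logsumexp(x, dim = -1)

scripted = torch.jit.script(Reduce(use_max = True))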

In my use case a function compiled by the torch.jit.script decorator is about 30% faster, which is a very nice performance boost! But I need to give up on its flexibility and freeze the function used as an argument to the decorated function (which is OK for deployment, where such flexibility is not needed, but not for running many different experiments during the development phase).
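To make “freezing” concrete, here is a sketch of what I do now, assuming a factory that bakes the choice in before scripting (make_forward is a made-up name):

import torch

def make_forward(use_max: bool):
    # Specialize the function per experiment, then script each variant separately
    if use_max:
        def fn(x):
            return torch.max(x, dim = -1).values
    else:
        def fn(x):
            return torch.logsumexp(x, dim = -1)
    return torch.jit.script(fn)

fwd_max = make_forward(True)   # one compiled variant per choice
fwd_lse = make_forward(False)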

So +1 from me for enabling Callable as an argument type in torch.jit.script, so that a callable can be passed directly without extra effort like the workaround proposed above.