Callable support in TorchScript IR

Hi PyTorch team! How can I get torch.jit.script to represent a Callable in the IR?

Specifically, in the example below, we use Tensor.map_ to apply the provided functor (here, one that computes the Euclidean distance between two elements) pointwise to two tensors.
Basically, we’re looking for a way to represent map_ in the JIT IR (e.g., as an aten::map_ node) and have the interpreter receive a function pointer to apply to each pair of corresponding elements in the operand tensors.
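
For reference, in eager mode x.map_(y, fn) calls fn on each pair of corresponding elements of x and y (with y broadcast to x's shape) and writes the results back into x; map_ only works on CPU tensors. The sketch below reproduces that behavior with a plain Python loop, to make explicit what a TorchScript representation would need to express. reference_map_ and abs_diff are hypothetical helpers, and abs_diff assumes the intended functor is the Euclidean distance |a - b| between two scalars:

```python
import itertools

import torch


def reference_map_(x: torch.Tensor, y: torch.Tensor, fn) -> torch.Tensor:
    """Hypothetical pure-Python equivalent of x.map_(y, fn) for CPU tensors."""
    # map_ requires y to be broadcastable to x's shape.
    y_b = y.expand_as(x)
    # Visit every multi-dimensional index of x.
    for idx in itertools.product(*(range(n) for n in x.shape)):
        # .item() converts each 0-d element to a Python number before calling fn;
        # the result is written back into x in place.
        x[idx] = fn(x[idx].item(), y_b[idx].item())
    return x


def abs_diff(a: float, b: float) -> float:
    # Hypothetical functor: Euclidean distance between two scalars.
    return abs(a - b)


x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])

expected = x.clone()
expected.map_(y, abs_diff)  # built-in eager implementation
actual = reference_map_(x.clone(), y, abs_diff)
print(actual.tolist())  # [4.0, 3.0, 1.0, 2.5]
print(torch.equal(expected, actual))  # True
```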

# test1.py

import torch
from typing import Callable

class Map(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: Callable):
        x.map_(y, functor)
        return x

def euclidean_dist(x: float, y: float) -> float:
    return abs(x - y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])

model = Map()
print(model(x, y, euclidean_dist))

script = torch.jit.script(model)

As expected, this fails, since Callable is not a supported type in TorchScript:

RuntimeError: 
Unknown type name 'Callable':
  File "test1.py", line 5
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: Callable):
                                                                 ~~~~~~~~ <--- HERE
        x.map_(y, functor)
        return x

I then tried using torch.jit.interface to specify the interface contract for the callable, but now I get a different error:

# test2.py

import torch

@torch.jit.interface
class FunctorInterface(object):
    def forward(self, x: float, y: float) -> float:
        pass

class Map(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: FunctorInterface):
        x.map_(y, functor.forward)
        return x

class EuclideanDist(torch.nn.Module):
    def forward(self, x: float, y: float) -> float:
        return abs(x - y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])

model = Map()
print(model(x, y, EuclideanDist()))

script = torch.jit.script(model)

Error:

RuntimeError: 
'Tensor' object has no attribute or method 'map_'.:
  File "test2.py", line 10
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: FunctorInterface):
        x.map_(y, functor.forward)
        ~~~~~~ <--- HERE
        return x
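
The error points at map_ itself, which suggests that map_ is not exposed to TorchScript regardless of how the callable is typed. One possible workaround, sketched under a few assumptions: the interface subclasses torch.nn.Module (making it a module interface), the functor is stored as a submodule attribute instead of being passed to forward, and map_ is replaced by an explicit element-by-element loop. FunctorModuleInterface, AbsDiff, and LoopMap are hypothetical names, and the loop makes one interpreter call per element, so it will be slow for large tensors:

```python
from typing import List

import torch


# A *module* interface: the interface class subclasses torch.nn.Module, so an
# attribute annotated with it can hold any submodule whose forward matches.
@torch.jit.interface
class FunctorModuleInterface(torch.nn.Module):
    def forward(self, x: float, y: float) -> float:
        pass


class AbsDiff(torch.nn.Module):
    # Hypothetical functor: Euclidean distance between two scalars.
    def forward(self, x: float, y: float) -> float:
        return abs(x - y)


class LoopMap(torch.nn.Module):
    # Class-level annotation: TorchScript types self.functor as the interface.
    functor: FunctorModuleInterface

    def __init__(self, functor: torch.nn.Module):
        super().__init__()
        self.functor = functor

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Broadcast y to x's shape (as map_ does) and flatten both for indexing.
        flat_x = x.reshape(-1)
        flat_y = y.expand_as(x).reshape(-1)
        out: List[float] = []
        for i in range(flat_x.numel()):
            # Call the functor through the interface method on scalar floats.
            out.append(self.functor.forward(float(flat_x[i]), float(flat_y[i])))
        # Write the results back into x in place, mirroring map_.
        x.copy_(torch.tensor(out, dtype=x.dtype).reshape(x.shape))
        return x


scripted = torch.jit.script(LoopMap(AbsDiff()))
x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])
print(scripted(x, y).tolist())  # [4.0, 3.0, 1.0, 2.5]
```

Inspecting scripted.graph is a quick way to check how the interface call ended up being represented in the IR.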

Modifying test1 slightly, suppose the concrete function could be specified at model instantiation instead. For example:

# test3.py

import torch
from typing import Callable

class Map(torch.nn.Module):
    def __init__(self, functor: Callable):
        super().__init__()
        self.functor = functor

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x.map_(y, self.functor)
        return x

def euclidean_dist(x: float, y: float) -> float:
    return abs(x - y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])

model = Map(euclidean_dist)
print(model(x, y))

script = torch.jit.script(model)

This still fails:

RuntimeError:
'Tensor' object has no attribute or method 'map_'.:
  File "test.py", line 12
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x.map_(y, self.functor)
        ~~~~~~ <--- HERE
        return x
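
If the functor is fixed at construction time and can be written with tensor operations, one option is to drop the per-element callable and express the computation directly, which TorchScript compiles without needing map_. A minimal sketch, assuming the intended functor is the Euclidean distance |x - y| between scalars; AbsDiffMap is a hypothetical name:

```python
import torch


class AbsDiffMap(torch.nn.Module):
    """Hypothetical vectorized stand-in for Map(euclidean_dist), assuming |x - y|."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # torch.abs(x - y) applies the pointwise function to every element pair
        # (broadcasting y against x); copy_ then writes the result into x in place.
        return x.copy_(torch.abs(x - y))


scripted = torch.jit.script(AbsDiffMap())
x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])
print(scripted(x, y).tolist())  # [4.0, 3.0, 1.0, 2.5]
```

This avoids a per-element call entirely, but it only works when the functor can be expressed with tensor operations; an arbitrary Python function still cannot be passed in this way.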