Hi PyTorch. How can I get jit.script to represent a Callable in the IR?
Specifically, in the example below, we use Tensor.map_ to apply the provided functor pointwise to two tensors (in this case a function named euclidean_dist, standing in for any elementwise binary operation on two elements).
Basically, we're looking for a way to represent aten::map_ in the JIT IR and have the interpreter receive a function pointer to apply to each pair of elements in the operand tensors.
# test1.py
import torch
from typing import Callable
class Map(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: Callable):
        x.map_(y, functor)
        return x

def euclidean_dist(x: float, y: float) -> float:
    return (x - y) * (x + y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])
model = Map()
print(model(x, y, euclidean_dist))
script = torch.jit.script(model)
As expected, this fails (since Callable is not supported):
RuntimeError:
Unknown type name 'Callable':
  File "test1.py", line 5
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: Callable):
                                                                 ~~~~~~~~ <--- HERE
        x.map_(y, functor)
        return x
I tried using torch.jit.interface to specify the interface contract for the Callable, but now I get a different error:
# test2.py
import torch
@torch.jit.interface
class FunctorInterface(object):
    def forward(self, x: float, y: float) -> float:
        pass

class Map(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: FunctorInterface):
        x.map_(y, functor.forward)
        return x

class EuclideanDist(torch.nn.Module):
    def forward(self, x: float, y: float) -> float:
        return (x - y) * (x + y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])
model = Map()
print(model(x, y, EuclideanDist()))
script = torch.jit.script(model)
Error:
RuntimeError:
'Tensor' object has no attribute or method 'map_'.:
  File "test2.py", line 10
    def forward(self, x: torch.Tensor, y: torch.Tensor, functor: FunctorInterface):
        x.map_(y, functor.forward)
        ~~~~~~ <--- HERE
        return x
Modifying test1 slightly, suppose we instead pass the concrete function at model instantiation. For example:
# test3.py
import torch
from typing import Callable
class Map(torch.nn.Module):
    def __init__(self, functor: Callable):
        super().__init__()
        self.functor = functor
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x.map_(y, self.functor)
        return x

def euclidean_dist(x: float, y: float) -> float:
    return (x - y) * (x + y)

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])
model = Map(euclidean_dist)
print(model(x, y))
script = torch.jit.script(model)
This still fails:
RuntimeError:
'Tensor' object has no attribute or method 'map_'.:
  File "test.py", line 12
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x.map_(y, self.functor)
        ~~~~~~ <--- HERE
        return x
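
For comparison, here is a minimal sketch of the closest variant that I would expect to script, assuming the functor can be rewritten as elementwise tensor operations instead of a Python function on floats. The names EuclideanDistOp and MapModule are hypothetical, and this sidesteps map_ entirely rather than representing it in the IR, so it is a workaround and not what we are actually asking for:

```python
# workaround_sketch.py (hypothetical file name)
import torch


class EuclideanDistOp(torch.nn.Module):
    # Hypothetical tensor-level rewrite of euclidean_dist: same arithmetic,
    # but applied to whole tensors rather than to pairs of Python floats.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x - y) * (x + y)


class MapModule(torch.nn.Module):
    # The functor is held as a submodule, so torch.jit.script compiles it
    # recursively and no Callable annotation is needed.
    def __init__(self, functor: torch.nn.Module):
        super().__init__()
        self.functor = functor

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # copy_ writes the result back into x, mirroring the in-place
        # behaviour of x.map_(y, functor).
        return x.copy_(self.functor(x, y))


x = torch.tensor([1.0, 2.0, 3.0, 6.0])
y = torch.tensor([5.0, -1.0, 4.0, 3.5])

# Eager reference computed with Tensor.map_ on a copy of x.
expected = x.clone().map_(y, lambda a, b: (a - b) * (a + b))

scripted = torch.jit.script(MapModule(EuclideanDistOp()))
result = scripted(x, y)
```

The limitation is that the functor has to be known at scripting time and expressible with tensor ops, which is exactly the restriction we were hoping to avoid by passing a Callable.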