Is there any way to give a type hint to the output of `torch.jit._wait`?
The example below fails to compile to TorchScript:
```python
import torch
from typing import List

def process(i: int) -> int:
    return i + 1

@torch.jit.script
def process_many(l: List[int]) -> List[int]:
    futs: List[torch.jit.Future] = []
    out: List[int] = []
    for v in l:
        futs.append(torch.jit._fork(process, v))
    for f in futs:
        out.append(torch.jit._wait(f))
    return out
```
with the error:

```
  File "/usr/local/lib/python3.7/site-packages/torch/jit/__init__.py", line 1281, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
Unknown type name 'torch.jit.Future':
  File "./demo.py", line 10
@torch.jit.script
def process_many(l: List[int]) -> List[int]:
    futs: List[torch.jit.Future] = []
               ~~~~~~~~~~~~~~~~ <--- HERE
    out: List[int] = []
```
However, if I change the first `for` loop to a list comprehension, it compiles fine:
```python
@torch.jit.script
def process_many(l: List[int]) -> List[int]:
    # No annotation needed: the element type is inferred from the comprehension
    futs = [torch.jit._fork(process, v) for v in l]
    out: List[int] = []
    for f in futs:
        out.append(torch.jit._wait(f))
    return out

# >>> print(process_many([0, 3, 5]))
# [1, 4, 6]
```
What I haven’t been able to figure out is whether `torch.jit.Future` is the correct type annotation to use here. For now, I can work around the problem with a list comprehension, but that could get ugly if the logic becomes more complex.
I’m on PyTorch 1.4.0, if it makes any difference.