For several reasons, including that my model contains control flow and that I want better device portability, I need to use scripting instead of tracing to compile at least part, and ideally all, of my model with TorchScript. Unfortunately, my model contains type conversions: predicted floats need to be converted to ints so they can be used as indices into other tensors. How can I cast a tensor to another type when compiling with scripting, given that the `type` method is not supported?
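A minimal sketch of that float-to-index step in a scripted function, assuming the installed PyTorch version accepts `Tensor.to(dtype)` inside TorchScript (the function name `gather_by_predicted_index` and the round-then-clamp policy are hypothetical choices for illustration, not part of the original model):

```python
import torch


@torch.jit.script
def gather_by_predicted_index(pred: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: round the float predictions, then cast to int64 with
    # `.to(dtype)` instead of `.type(dtype)`; tensor indexing needs integer indices.
    idx = pred.round().to(torch.int64)
    # Keep the indices inside the valid range of the first dimension of `table`.
    idx = idx.clamp(0, table.size(0) - 1)
    return table[idx]


pred = torch.tensor([0.2, 1.7, 5.0])
table = torch.tensor([10.0, 20.0, 30.0])
out = gather_by_predicted_index(pred, table)
```

`pred.long()` should be an equivalent shorthand for the cast if it reads better in the model code.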
My ideas so far are either to (A) call traced functions only for the type conversions (which makes the model definition code messy, since I have to create traced callables), or (B) wrap the control flow in `@torch.jit.script` decorators and make the consumers of my model deal with the device portability constraints. Is A the right way to go here, or is there an approach that is better than both A and B?
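Option A could look roughly like the sketch below, assuming the "mixing tracing and scripting" behavior described in the `torch.jit` docs, where scripted code can call an already traced function. The names `_to_half`, `traced_to_half`, and `branch`, and the toy control flow inside `branch`, are hypothetical placeholders, not part of the original model:

```python
import torch


def _to_half(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: this line is only ever traced, never scripted,
    # so the unsupported `.type` call never reaches the script compiler.
    return x.type(torch.float16)


# Tracing records the dtype conversion as an op in the traced graph.
traced_to_half = torch.jit.trace(_to_half, torch.rand(2, 3))


@torch.jit.script
def branch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical scripted function: the control flow stays scripted,
    # and only the cast is delegated to the traced callable.
    if bool(x.sum() > 0):
        x = x * 2
    return traced_to_half(x)


out = branch(torch.ones(4, 5))
```

The trade-off is the one described in A: every cast needs its own traced wrapper defined outside the scripted code.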
It would be great if `type` were supported by the script compiler.
Here’s an example to reproduce the type conversion error:
```python
from typing import Optional, Tuple

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_branch_0 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv_branch_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        x0: torch.Tensor = self.conv_branch_0(inputs)
        x1: torch.Tensor = self.conv_branch_1(inputs)
        x1 = x1.mean(dim=(-1, -2))
        x1 = x1.type(torch.float16)  # `.type` method is unsupported for script-based compilation
        return x0, x1


def convert(
    model: torch.nn.Module,
    example: Optional[torch.Tensor] = None,
    input_dims: Tuple[int, int] = (224, 224),
):
    if example is None:
        example = torch.rand(1, 3, *input_dims)
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing:
    script_module = torch.jit.trace(model, example)
    # Vs. scripting:
    script_module = torch.jit.script(model)
    # `torch.jit.script(model)` fails with:
    # RuntimeError:
    # Tried to access nonexistent attribute or method 'type' of type 'Tensor'.:
    # File "torchscript_compile_example.py", line 17
    #     x1: torch.Tensor = self.conv_branch_1(inputs)
    #     x1 = x1.mean(dim=(-1, -2))
    #     x1 = x1.type(torch.float16)
    #          ~~~~~~~ <--- HERE
    #     return x0, x1
    return script_module


def main():
    model = Model()
    model.eval()
    model = convert(model)
    model.save("example-model.pt")


if __name__ == "__main__":
    main()
```
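One variant worth trying is the same model with `.type(torch.float16)` swapped for `.to(torch.float16)`. This is a sketch assuming the installed PyTorch version compiles `Tensor.to(dtype)` under `torch.jit.script`; the class name `ScriptableModel` and the smaller 32x32 test input are hypothetical choices to keep the example quick:

```python
from typing import Tuple

import torch
from torch import nn


class ScriptableModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv_branch_0 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv_branch_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x0 = self.conv_branch_0(inputs)
        x1 = self.conv_branch_1(inputs)
        x1 = x1.mean(dim=(-1, -2))
        # `.to(dtype)` performs the same cast as `.type(dtype)` here;
        # `.half()` would be an equivalent shorthand for float16.
        x1 = x1.to(torch.float16)
        return x0, x1


scripted = torch.jit.script(ScriptableModel().eval())
x0, x1 = scripted(torch.rand(1, 3, 32, 32))
```

If this compiles, `scripted.save("example-model.pt")` should work the same way as the traced module in `main()`.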
Thanks in advance for any recommendations here!