Type casting in `torch.jit.script`

For several reasons, including that my model contains control flow and for better device portability, I need to use scripting instead of tracing to compile at least parts, and ideally all, of my model to TorchScript. Unfortunately, my model contains type conversions: predicted floats need to be converted to int to be used as indices into other tensors. How can I cast a tensor to another type with the scripting approach when the `type` method is not supported?

The ideas I have are to either (A) call traced functions only for the type conversions (which would make the model definition code a mess of traced callables) or (B) wrap the control flow in `@torch.jit.script` decorators and make the consumers of my model deal with the device portability constraints of tracing. Is A the appropriate way to go here, or is there a better way than A or B?
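For reference, option (B) relies on mixing scripting and tracing: a `@torch.jit.script` function called from code that is later traced keeps its control flow, while everything around it is recorded by the tracer. The sketch below is a minimal illustration under that assumption; `branch` and `Wrapper` are hypothetical names, not part of the model above, and any device-specific constants captured by the traced part would still be baked in.

```python
import torch
from torch import nn


@torch.jit.script
def branch(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: tracing alone would record only the path
    # taken for the example input, but this scripted function keeps both.
    if bool(x.mean() > 0):
        return x * 2
    else:
        return -x


class Wrapper(nn.Module):  # hypothetical module standing in for the traced part
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return branch(x) + 1


# Trace with a positive example; the scripted `branch` is still called at runtime.
traced = torch.jit.trace(Wrapper(), torch.ones(3))
```

Because the negative input takes the `else` branch even though tracing only saw a positive example, the control flow survived tracing.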

It would be great if `type` were a supported method for script compilation.

Here’s an example that reproduces the type conversion error:

```python
from typing import Optional, Tuple

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_branch_0 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv_branch_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x0: torch.Tensor = self.conv_branch_0(inputs)
        x1: torch.Tensor = self.conv_branch_1(inputs)
        x1 = x1.mean(dim=(-1, -2))
        x1 = x1.type(torch.float16)  # `.type` method is unsupported for script-based compilation
        return x0, x1

def convert(
    model: torch.nn.Module,
    example: Optional[torch.Tensor] = None,
    input_dims: Tuple[int, int] = (224, 224),
) -> torch.jit.ScriptModule:
    if example is None:
        example = torch.rand(1, 3, *input_dims)

    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing (this works):
    script_module = torch.jit.trace(model, example)
    # Vs script (this overwrites the traced module and fails):
    script_module = torch.jit.script(model)
    # `torch.jit.script(model)` fails with:
    # RuntimeError:
    # Tried to access nonexistent attribute or method 'type' of type 'Tensor'.:
    #   File "torchscript_compile_example.py", line 17
    #         x1: torch.Tensor = self.conv_branch_1(inputs)
    #         x1 = x1.mean(dim=(-1, -2))
    #         x1 = x1.type(torch.float16)
    #              ~~~~~~~ <--- HERE
    #         return x0, x1

    return script_module

def main():
    model = Model()
    model = convert(model)

if __name__ == "__main__":
    main()
```
Thanks in advance for any recommendations here!

`.to` is a great way to cast the dtype, e.g. `x1.to(torch.float16)`.
`type` is one of those odd legacy methods that haven’t been properly deprecated; I would probably not use it in any of my projects.
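A minimal sketch of the fix, assuming the model from the question with `.type(torch.float16)` swapped for `.to(torch.float16)`. It also covers the original motivation (casting predicted floats to integer indices) with `.to(torch.long)`; the helper `select_by_prediction`, its rounding and its clamping are hypothetical illustrations, not something from the thread. The `mean` call uses a list for `dim` to stay on the plain `List[int]` path.

```python
from typing import Tuple

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_branch_0 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv_branch_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x0 = self.conv_branch_0(inputs)
        x1 = self.conv_branch_1(inputs).mean(dim=[-1, -2])
        # `.to(dtype)` compiles under torch.jit.script, unlike `.type(dtype)`.
        x1 = x1.to(torch.float16)
        return x0, x1


@torch.jit.script
def select_by_prediction(values: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: round float predictions, cast them to int64,
    # clamp them into range, then use them as indices into `values`.
    indices = predicted.round().to(torch.long)
    indices = indices.clamp(0, values.size(0) - 1)
    return values.index_select(0, indices)


# The whole module now compiles with scripting, no tracing needed.
scripted = torch.jit.script(Model())
x0, x1 = scripted(torch.rand(1, 3, 32, 32))
```

Since `.to` accepts a dtype (and also a device), the same call covers both the float16 cast and the integer cast needed for indexing.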


Thanks @tom, that fixes it! Good to know that `type` is effectively deprecated, especially since it isn’t noted as such in the docs: torch.Tensor (PyTorch 1.8.0 documentation).