Getting Unknown type annotation error when JIT saving

When JIT-saving “model.pt” from a complex PyTorch model with many custom classes, I get an error saying that PyTorch doesn’t know the type annotation of one of those custom classes. In other words, the following code (drastically simplified from the original) fails on the seventh line, the torch.jit.script call:

import torch
from gan import Generator
from gan.blocks import SpadeBlock

generator = Generator()
generator.load_weights("path/to/weights")
jitted = torch.jit.script(generator)
torch.jit.save(jitted, "model.pt")

Error:

Traceback (most recent call last):
  File "pth2onnx.py", line 72, in <module>
    to_torch_jit(generator)
  File "pth2onnx.py", line 24, in to_torch_jit
    jitted = torch.jit.script(generator)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/__init__.py", line 1516, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 310, in create_script_module
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 138, in infer_concrete_type_builder
    sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 138, in infer_concrete_type_builder
    sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 126, in infer_concrete_type_builder
    attr_type = infer_type(name, item)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 99, in infer_type
    attr_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/annotations.py", line 303, in ann_to_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
ValueError: Unknown type annotation: '<class 'gan.blocks.SpadeBlock'>'

The type it complains about is indeed a class we wrote ourselves and use in the loaded Generator. I would appreciate any pointers on what could cause this or how to investigate it!
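One way to investigate: the last frames of the traceback show infer_type looking an attribute name up in class_annotations and handing the result to ann_to_type, so it can help to list every class-level annotation your model classes declare. The sketch below is plain Python and only approximates what TorchScript inspects; class_level_annotations and report_annotations are hypothetical helper names, not PyTorch APIs.

```python
import inspect


def _own_annotations(klass):
    """Return only the annotations declared directly in klass's own body."""
    if hasattr(inspect, "get_annotations"):  # Python 3.10+
        return dict(inspect.get_annotations(klass))
    # Older Pythons: read the class __dict__ so base-class entries are not picked up here.
    return dict(klass.__dict__.get("__annotations__", {}))


def class_level_annotations(cls):
    """Merge class-body annotations along the MRO; more derived classes win."""
    merged = {}
    # reversed MRO: start at `object`, end at cls itself.
    for klass in reversed(inspect.getmro(cls)):
        merged.update(_own_annotations(klass))
    return merged


def report_annotations(obj):
    """Print every annotated attribute of obj's class (or of obj, if it is a class)."""
    cls = obj if inspect.isclass(obj) else type(obj)
    for name, ann in class_level_annotations(cls).items():
        print(f"{cls.__name__}.{name}: {ann!r}")
```

On a real model you could call report_annotations on each entry of generator.modules() and look for attributes annotated with a class, such as SpadeBlock or nn.Module, that TorchScript may not accept as a type. Expect some entries inherited from nn.Module itself, depending on the PyTorch version.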

I tried the following:

  • explicitly imported SpadeBlock in the script that calls torch.jit.script
  • ensured it inherits from nn.Module (as does Generator)
  • ensured the gan package is installed, using pip install --user -e <directory>

Any ideas? Thanks in advance!

Which version of torch are you using? If it is anything below 1.6, I think you need to script SpadeBlock as well by decorating its definition with @torch.jit.script.

Thanks for responding.

I checked, but I already had torch==1.6.0 installed. I’m using Python 3.6.0, in case it matters. Just to be sure, I reinstalled with pip uninstall torch; pip install --user torch==1.6.0, and now the error has become even stranger: it can’t identify nn.Module!

Traceback (most recent call last):
  File "pth2onnx.py", line 72, in <module>
    to_torch_jit(generator)
  File "pth2onnx.py", line 24, in to_torch_jit
    jitted = torch.jit.script(generator)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/__init__.py", line 1516, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 310, in create_script_module
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 126, in infer_concrete_type_builder
    attr_type = infer_type(name, item)
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 99, in infer_type
    attr_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
  File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/annotations.py", line 303, in ann_to_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
ValueError: Unknown type annotation: '<class 'torch.nn.modules.module.Module'>'

EDIT: After some more testing, the difference in the unknown type annotation appears to be because I used a different generator class in the second export. Is it possible that PyTorch’s JIT doesn’t support type hints? I’m using them heavily.

Can you post a small example that I can play with to investigate?

I ran this exact script and got the “Unknown type annotation” nn.Module error. It also includes a successful ONNX export as a sanity check that nothing is wrong with the model itself.

import torch
import torch.nn as nn


class DcganGenerator(nn.Module):
    __main: nn.Module

    def __init__(self, shape_originals, shape_targets):
        super().__init__()
        num_channels = shape_originals[0]
        shape_originals  # TODO upsample to this shape
        ngf = 64
        ndf = 64

        self.__main = nn.Sequential(
            nn.ConvTranspose2d(num_channels, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.PReLU(),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.PReLU(),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.PReLU(),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 1, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.PReLU(),
            nn.ConvTranspose2d(ngf, num_channels, 4, 1, 1, bias=False),
            nn.PReLU(),
        )

    def forward(self, inputs):
        return {"generated": self.__main(inputs)}

def load_generator():
    print("Instantiating generator")
    g = DcganGenerator([3, 128, 128], [3, 512, 512])
    print(g)
    return g


def to_torch_jit(generator):
    print("Converting to JIT pytorch")
    try:
        jitted = torch.jit.script(generator)
        torch.jit.save(jitted, "model.pt")

        print("Created model.pt")
    except Exception as e:
        print(f"Failed. {e}")


def to_onnx(generator):
    print("Converting to ONNX")
    try:
        torch.onnx.export(
            generator,
            torch.rand(1, 3, 128, 128),
            "model.onnx",
            export_params=True,
            opset_version=11,
            input_names=["in"],
            output_names=["out"],
        )
        print("Created model.onnx")
    except Exception as e:
        print(f"Failed. {e}")


generator = load_generator()
print()
to_torch_jit(generator)
to_onnx(generator)

Saved as jit-it.py, output:

$ python3 jit-it.py 
Instantiating generator
DcganGenerator(
  (_DcganGenerator__main): Sequential(
    (0): ConvTranspose2d(3, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): PReLU(num_parameters=1)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): PReLU(num_parameters=1)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): PReLU(num_parameters=1)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    (13): PReLU(num_parameters=1)
  )
)

Converting to JIT pytorch
Failed. Unknown type annotation: '<class 'torch.nn.modules.module.Module'>'
Converting to ONNX
Created model.onnx

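Your example works for me if I remove the class-level declaration of __main (the one outside of __init__), since nn.Module is not a supported type annotation, and rename the attribute to _main. The rename is probably needed because of the special handling of identifiers that begin with two underscores: Python mangles self.__main to self._DcganGenerator__main, as the printed module above shows.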
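The name-mangling part can be checked without torch. Python rewrites double-underscore names written inside a class body, both in self.__main = ... and in the class-level annotation. The torch-free stand-in below shows that a lookup by the literal name "__main" then finds nothing. The class names and placeholder string are illustrative only, and that TorchScript fails for exactly this reason remains a plausible guess, not something this sketch proves.

```python
class DcganGenerator:
    """Torch-free stand-in mimicking the attribute layout of the repro."""

    __main: object  # stored in __annotations__ under '_DcganGenerator__main'

    def __init__(self):
        # Written inside the class body, so this sets self._DcganGenerator__main.
        self.__main = "submodule placeholder"


class RenamedGenerator:
    """Same layout with a single leading underscore: no mangling happens."""

    _main: object

    def __init__(self):
        self._main = "submodule placeholder"


g = DcganGenerator()
print(DcganGenerator.__annotations__)  # {'_DcganGenerator__main': <class 'object'>}
print(vars(g))                          # {'_DcganGenerator__main': 'submodule placeholder'}
print(hasattr(g, "__main"))             # False: string lookups are never mangled

r = RenamedGenerator()
print(RenamedGenerator.__annotations__)  # {'_main': <class 'object'>}
print(hasattr(r, "_main"))               # True
```

For the real model, the fix described above amounts to deleting the __main: nn.Module line and writing self._main in both __init__ and forward.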

Ohhh. Thanks for helping figure that out! Works here now too.