Error with my custom concat class with TorchScript

TorchScript is a statically typed language, and in most cases we're forced to assume that unannotated variables, such as the arguments to forward, are instances of torch.Tensor. (One of our projects this half is improving our type inference.) The easy fix is to annotate the relevant variables with their correct types. Here, x is a list of tensors, so annotate it as List[torch.Tensor]:

def forward(self, x: List[torch.Tensor]) -> torch.Tensor:  # List is typing.List
    return torch.cat(x, dim=self.dim)
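For context, here is a minimal, self-contained sketch of a concat module that compiles with torch.jit.script once the annotation is in place. The class name Concat, the default dim=1, and the example tensor shapes are illustrative assumptions, not code from the original question. Without the List[torch.Tensor] annotation, TorchScript would treat x as a single Tensor, and compiling the call to torch.cat(x, ...) would fail because torch.cat expects a list of tensors.

```python
from typing import List

import torch
import torch.nn as nn


class Concat(nn.Module):
    """Hypothetical concat module: joins a list of tensors along `dim`."""

    def __init__(self, dim: int = 1):
        super().__init__()
        # Plain int attribute; TorchScript infers its type from the value.
        self.dim = dim

    # The List[torch.Tensor] annotation tells TorchScript that `x` is a list,
    # instead of the default assumption that an unannotated argument is a Tensor.
    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat(x, dim=self.dim)


# Compile the module; this is where an unannotated `x` would raise an error.
scripted = torch.jit.script(Concat(dim=1))

a = torch.ones(2, 3)
b = torch.zeros(2, 4)
out = scripted([a, b])
print(out.shape)  # torch.Size([2, 7])
```

The scripted module is called exactly like the eager one, with a Python list of tensors.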