TorchScript is a statically typed language, and in most cases we’re forced to infer that untyped variables are instances of torch.Tensor. (One of our projects this half is improving our type inference.) An easy fix is to annotate the relevant variables with their correct types:
def forward(self, x: List[torch.Tensor]):
    return torch.cat(x, dim=self.dim)
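
For context, here is a minimal self-contained sketch of how that annotated forward might sit inside a scripted module. The class name Concat and the default dim value are assumptions made for illustration, not taken from your code.

```python
from typing import List

import torch
import torch.nn as nn


class Concat(nn.Module):  # hypothetical module name, for illustration only
    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        # Without the List[torch.Tensor] annotation, x would be inferred
        # as a single torch.Tensor, which does not match what torch.cat expects.
        return torch.cat(x, dim=self.dim)


# Compile the module with TorchScript and call it with a list of tensors.
scripted = torch.jit.script(Concat(dim=1))
out = scripted([torch.ones(2, 3), torch.zeros(2, 2)])
print(out.shape)
```

Note the `from typing import List` import, which the annotation needs in order to resolve.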