torch.cat throws an error for a tensor list when compiling with TorchScript

torch.cat throws an error for tensor lists when used within TorchScript.
Please let me know of a fix or workaround.

Here is a minimal example to reproduce the bug.

```python
import torch
import torch.nn as nn

"""
Smallest working bug for torch.cat torchscript
"""


class Model(nn.Module):
    """dummy model for showing error"""

    def __init__(self):
        super(Model, self).__init__()
        pass

    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
        return out


if __name__ == '__main__':
    model = Model()
    print(model())  # works
    torch.jit.script(model)  # throws error
```

This code throws the following error:
```
File "/home/anil/.conda/envs/rnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1423, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError: 
Arguments for call are not valid.
The following operator variants are available:
  
  aten::cat(Tensor[] tensors, int dim=0) -> (Tensor):
  Keyword argument axis unknown.
  
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument out not provided.

The original call is:
at smallest_working_bug_torch_cat_torchscript.py:19:14
    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
              ~~~~~~~~~ <--- HERE
        return out
```

Thank you for your consideration.

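In eager mode, PyTorch accepts the `axis` keyword argument as an alias for `dim` for NumPy compatibility, but it looks like there is a bug where this alias isn't carried over into TorchScript: as the error shows, the compiled schema `aten::cat(Tensor[] tensors, int dim=0)` only knows `dim`. Most TorchScript ops use `dim` in place of `axis` (the meaning is the same), so renaming the keyword in your code should make it work. For example, `torch.cat([a, b], axis=2)` becomes `torch.cat([a, b], dim=2)`.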
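A minimal sketch of the repro from the question with that one change applied, assuming nothing else about the model needs to differ: `axis=2` is replaced by `dim=2`, and the now-empty `__init__` is dropped because `nn.Module` provides one. The `scripted` variable name is just for illustration.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Same dummy model as in the question, but using `dim` instead of `axis`."""

    def forward(self) -> torch.Tensor:
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        # TorchScript's aten::cat schema accepts `dim`, not the NumPy-style `axis` alias.
        out = torch.cat([a, b], dim=2)
        return out


if __name__ == '__main__':
    model = Model()
    scripted = torch.jit.script(model)  # compiles without "Keyword argument axis unknown"
    out = scripted()
    print(out.shape)  # torch.Size([6, 1, 24])
```

Since `dim` is an ordinary positional parameter in that schema, passing it positionally as `torch.cat([a, b], 2)` should also compile.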
