Unexpected `torch.jit.trace` error with `torch.cat(..., dim=-1)`

Below is a minimal reproducible example that crashes in both PyTorch 1.1 and PyTorch 1.2 with CUDA (it works on CPU).



import torch
from torch import nn


device = torch.device('cuda') # crashes with cuda, works with cpu


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 16)
        self.linear2 = nn.Linear(2, 16)

    def forward(self, x, y):
        x = self.linear1(x)
        y = self.linear2(y)
        return torch.cat([x, y], dim=-1) # works fine if we replace -1 with 1


model = Model().to(device)

data = [torch.randn(1, 2).to(device), torch.randn(1, 2).to(device)] 

traced = torch.jit.trace(model, data)

print(traced)

Surprisingly, the above works on the CPU backend but not on the CUDA backend. It also works with torch.cat(..., dim=1), but crashes with the equivalent negative dimension, torch.cat(..., dim=-1).
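To be clear that the two calls are supposed to be interchangeable here: for 2-D tensors, dim=-1 resolves to dim=1, so outside of tracing they produce identical results. A quick CPU-only check (no jit involved):

```python
import torch

a = torch.randn(1, 16)
b = torch.randn(1, 16)

# for a 2-D tensor, dim=-1 and dim=1 address the same (last) dimension
assert torch.equal(torch.cat([a, b], dim=-1), torch.cat([a, b], dim=1))
print(torch.cat([a, b], dim=-1).shape)  # torch.Size([1, 32])
```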

The jit.trace error is below (not very informative):

torch.jit.TracingCheckError: Tracing failed sanity checks!
Encountered an exception while running the trace with test inputs.
Exception:
        vector::_M_range_check: __n (which is 18446744073709551615) >= this->size() (which is 2)
        The above operation failed in interpreter, with the following stack trace:


This is a bug; you can track it in the corresponding GitHub issue. Any updates / fixes that go in will get posted there.

As a workaround, you can wrap the negative index around manually, e.g. dim=len(x.shape) + (-1)
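A sketch of that workaround applied to the model above: since len(x.shape) - 1 is a non-negative index, the traced graph never sees -1, and tracing should succeed on both backends.

```python
import torch
from torch import nn


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 16)
        self.linear2 = nn.Linear(2, 16)

    def forward(self, x, y):
        x = self.linear1(x)
        y = self.linear2(y)
        # wrap the negative index manually: -1 becomes len(x.shape) - 1
        return torch.cat([x, y], dim=len(x.shape) - 1)


model = Model()
data = [torch.randn(1, 2), torch.randn(1, 2)]
traced = torch.jit.trace(model, data)
print(traced(*data).shape)  # torch.Size([1, 32])
```

The same change works on CUDA by moving model and inputs with .to(device) as in the original snippet.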

Yup, sorry! I realised that later and filed an issue about it. x.ndim - 1 should also work.
