torch.jit.trace unexpected error with `torch.cat(..., dim=-1)`

Find below a minimal reproducible example that crashes in both PyTorch 1.1 and PyTorch 1.2 with CUDA (it works with CPU).

import torch 
from torch import nn

device = torch.device('cuda') # crashes with cuda, works with cpu

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 16)
        self.linear2 = nn.Linear(2, 16)

    def forward(self, x, y):
        x = self.linear1(x)
        y = self.linear2(y)
        return torch.cat([x, y], dim=-1) # if we replace -1 with 1 it works fine

model = Model().to(device)

data = [torch.randn(1, 2).to(device), torch.randn(1, 2).to(device)] 

traced = torch.jit.trace(model, data)


Surprisingly, the above works with the CPU backend but not with the CUDA backend. It also works with `torch.cat([x, y], dim=1)` but crashes with a negative dimension referring to the same axis, `torch.cat([x, y], dim=-1)`.

Find the `jit.trace` error below (it is not very explanatory):

torch.jit.TracingCheckError: Tracing failed sanity checks!
Encountered an exception while running the trace with test inputs.
        vector::_M_range_check: __n (which is 18446744073709551615) >= this->size() (which is 2)
        The above operation failed in interpreter, with the following stack trace:


This is a bug; you can track it in the corresponding GitHub issue, and any updates / fixes that go in will get posted there.

As a workaround, you can wrap the negative index manually with something like `dim=len(x.shape) + (-1)`.
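The wrapping trick above can be sketched as a small helper; `normalize_dim` is a hypothetical name for illustration, not part of the PyTorch API:

```python
def normalize_dim(dim, ndim):
    # Hypothetical helper: map a possibly negative dimension index to its
    # non-negative equivalent, mirroring dim = len(x.shape) + (-1) above.
    return dim + ndim if dim < 0 else dim

# For a 2-D tensor of shape (1, 16), dim=-1 maps to dim=1:
assert normalize_dim(-1, 2) == 1
assert normalize_dim(1, 2) == 1
```

Calling e.g. `torch.cat([x, y], dim=normalize_dim(-1, len(x.shape)))` inside `forward` keeps the negative index out of the traced graph.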

Yup, sorry! Realised that later and filed an issue for it. `x.ndim - 1` should also work.
