Permute fails at inference time on PyTorch-XLA

I have a model whose forward pass contains something like the following:

        ...
        all_feats, _ = self.rnn(sequence_output)
        conved = all_feats.unsqueeze(1)
        conved = F.relu(self.conv(conved))
        conved = conved.squeeze()
        conved = conved.permute(0, 2, 1)  # <-- it blows up here
        ...
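For anyone trying to reproduce this, here is a minimal sketch of a module that matches the fragment above. Everything not shown in the snippet is an assumption: the hypothetical name SketchModel, a batch-first nn.GRU for self.rnn, an nn.Conv2d whose kernel spans the full hidden width for self.conv, and all layer sizes. With a training-sized batch the tensor is still 3-D after squeeze(), so permute(0, 2, 1) succeeds, which matches the training-time behaviour described in the post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchModel(nn.Module):
    """Hypothetical reconstruction of the forward fragment; all sizes are assumptions."""

    def __init__(self, emb_dim=32, hidden=64, channels=16, kernel_h=3):
        super().__init__()
        # Assumed: a batch-first GRU over a [batch, seq_len, emb_dim] input
        self.rnn = nn.GRU(emb_dim, hidden, batch_first=True)
        # Assumed: the kernel spans the full hidden width, so the last dim becomes 1
        self.conv = nn.Conv2d(1, channels, kernel_size=(kernel_h, hidden))

    def forward(self, sequence_output):
        all_feats, _ = self.rnn(sequence_output)  # [B, T, hidden]
        conved = all_feats.unsqueeze(1)           # [B, 1, T, hidden]
        conved = F.relu(self.conv(conved))        # [B, channels, T - kernel_h + 1, 1]
        conved = conved.squeeze()                 # drops EVERY size-1 dim
        return conved.permute(0, 2, 1)            # requires a 3-D tensor


model = SketchModel()
out = model(torch.randn(8, 20, 32))  # batch of 8, sequence length 20
print(out.shape)  # torch.Size([8, 18, 16])
```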

The traceback XLA throws is weird:

    Exception in device=TPU:7: torch_xla/csrc/helpers.cpp:86 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim

    *** Begin stack trace ***
        tensorflow::CurrentStackTrace[abi:cxx11]()
        torch_xla::XlaHelpers::GetCanonicalDimensionIndex(long long, long long)
        torch_xla::XlaHelpers::GetCanonicalDimensionIndices(absl::Span<long long const>, long long)
        torch_xla::XLATensor::permute(torch_xla::XLATensor const&, absl::Span<long long const>)
        torch_xla::AtenXlaType::permute(at::Tensor const&, c10::ArrayRef<long>)
        c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, c10::ArrayRef<long>), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<long> > >, at::Tensor (at::Tensor const&, c10::ArrayRef<long>)>::call(c10::OperatorKernel*, at::Tensor const&, c10::ArrayRef<long>)

      File "<ipython-input-26-582cd081b9b1>", line 67, in forward
        conved = conved.permute(0, 2, 1)

    Value out of range (expected to be in range of [-2, 1], but got 2)
    Exception in device=TPU:4: torch_xla/csrc/helpers.cpp:86 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim

    *** End stack trace ***
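The "Value out of range" line is the informative part. For a tensor with n dimensions, PyTorch accepts dim indices in [-n, n - 1], so a valid range of [-2, 1] means the tensor reaching permute was 2-D, while permute(0, 2, 1) needs three dimensions. The small CPU-only sketch below (no XLA involved; the shape is made up, and valid_dim_range is a hypothetical helper) shows the same mismatch. Eager PyTorch reports it with its own message, and the exact exception type and text may differ between versions.

```python
import torch


def valid_dim_range(t: torch.Tensor) -> tuple:
    # Dim indices accepted for a tensor with t.dim() dimensions
    return (-t.dim(), t.dim() - 1)


x = torch.randn(16, 18)  # a 2-D tensor, like the one reaching permute
print(valid_dim_range(x))  # (-2, 1), the range quoted in the XLA error

try:
    x.permute(0, 2, 1)  # asks for three dims on a 2-D tensor
    failed = False
except (RuntimeError, IndexError) as err:  # type/message depend on the PyTorch version
    failed = True
    print(type(err).__name__, err)
```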

NB: There’s no error while the model is training; the error happens only at inference time. I can’t understand why that would happen in the first place.

I can share more details about the model via a private gist or similar if needed.

Thanks!

I assume conved.squeeze() removes dimensions you don’t want removed during inference, since conved has only two dimensions after the squeeze. The error message agrees: a valid range of [-2, 1] means permute received a 2-D tensor, while permute(0, 2, 1) needs three dimensions.
I’m not sure which dimensions you want to squeeze in this line of code, but note that all dimensions with size 1 will be removed.
If you are using a batch size of 1 during inference, the batch dimension will be removed, too.
I would recommend passing the dim argument to tensor.squeeze to make sure only the desired dimensions are removed.
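A minimal sketch of both the failure and the suggested fix, using a tensor shaped like an assumed conv output of [batch, channels, length, 1] (the sizes and the helper names squeeze_all and squeeze_last are made up for illustration). Bare squeeze() also drops the batch dimension when the batch size is 1, whereas squeeze(-1) removes only the trailing dimension.

```python
import torch


def squeeze_all(t: torch.Tensor) -> torch.Tensor:
    return t.squeeze()  # removes every size-1 dim, including a batch dim of 1


def squeeze_last(t: torch.Tensor) -> torch.Tensor:
    return t.squeeze(-1)  # removes only the trailing dim, and only if it has size 1


for batch_size in (8, 1):
    conv_out = torch.randn(batch_size, 16, 18, 1)  # hypothetical conv output shape
    # squeeze_all gives (8, 16, 18) for batch 8 but (16, 18) for batch 1;
    # squeeze_last always gives (batch_size, 16, 18)
    print(batch_size, tuple(squeeze_all(conv_out).shape), tuple(squeeze_last(conv_out).shape))
    # With the explicit dim, permute(0, 2, 1) works for any batch size
    print(tuple(squeeze_last(conv_out).permute(0, 2, 1).shape))
```

squeeze(dim) leaves the tensor unchanged when that dimension is not size 1, so it does not raise either. If the channel count or the conv output length could also be 1, passing an explicit dim is what keeps the layout predictable.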


@ptrblck wrote: "I’m not sure which dimensions you would like to squeeze in this line of code, but note that all dimensions with size==1 will be removed."

I’ll be more precise next time!

Yep, that was exactly the bug: I was using a batch size of 1 at inference.

Thanks @ptrblck <3
