nvFuser fails with "Vectorized dim has to be from a contiguous inner most position"

We have a point cloud vision model that fails to run using torch.jit and nvFuser during the forward pass. Unfortunately I am unable to share the model or code publicly, but I am hoping that I can get some generic guidance that I can investigate further.

I have tested with both PyTorch 1.12 and 1.13 and get the same error message in both. Unexpectedly, the difference is that in 1.12 it fails in the backward pass (so the model can run forward-only), whereas in 1.13 the same error already appears in the forward pass.

In this case, the ScriptModule is created using torch.jit.script, and forward-pre-hooks are removed as they are not JIT compatible.
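
For context, a minimal sketch of the scripting step. The helper below is illustrative rather than our exact code; it assumes the pre-hooks were registered via register_forward_pre_hook and simply clears them from every submodule before scripting:

import torch
import torch.nn as nn

def strip_forward_pre_hooks(module: nn.Module) -> nn.Module:
    # Illustrative helper: drop all forward pre-hooks from the module tree
    # before scripting, since they are not JIT compatible in our setup.
    for m in module.modules():
        m._forward_pre_hooks.clear()
    return module

# Toy stand-in for the real model: a pre-hook is attached and then stripped.
model = nn.Sequential(nn.Conv2d(16, 32, 1), nn.BatchNorm2d(32), nn.ReLU())
model.register_forward_pre_hook(lambda mod, inp: None)
scripted = torch.jit.script(strip_forward_pre_hooks(model))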

With PyTorch 1.12 and the default environment settings, torch.jit gives the following error:

/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-12/lib/python3.10/site-packages/torch/autograd/__init__.py:173: UserWarning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484808560/work/torch/csrc/jit/codegen/cuda/manager.cpp:329.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Setting PYTORCH_NVFUSER_DISABLE=fallback gives the following detailed traceback:

Traceback (most recent call last):
  File "/*****/train*****.py", line 262, in <module>
    main(args, runtime_manager)
  File "/*****/train*****.py", line 195, in main
    scaler.scale(loss).backward()
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-12/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-12/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Vectorized dim has to be from a contiguous inner most position: T46_l[ iblockIdx.y292{T8.size[1]}, sbS221{( ceilDiv(1, gridDim.z) )}, iS229{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(T8.size[0], 4) ), blockDim.x) ), 1) ), gridDim.x) )}, iS293{T8.size[3]}, sbblockIdx.z220{gridDim.z}, iblockIdx.x228{gridDim.x}, ithreadIdx.x225{blockDim.x}_p, iUS227{1}, iV223{4} ] ca_pos( 8 )

In PyTorch 1.13, with the fallback disabled, the error message is:

Traceback (most recent call last):
  File "/*****/train*****.py", line 263, in <module>
    main(args, runtime_manager)
  File "/*****/train*****.py", line 189, in main
    seg_pred, trans_feat = classifier(*local_module._neighbours(local_module, (points,)))
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-13/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-13/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-13/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/*****/intel/oneapi/intelpython/latest/envs/*****_pytorch1-13/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Vectorized dim has to be from a contiguous inner most position: T36_l[ iblockIdx.y226{T0.size[1]}, bS204{( ceilDiv(1, gridDim.z) )}, iS212{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i0, 8) ), blockDim.x) ), 1) ), gridDim.x) )}, iS105{i4}, bblockIdx.z203{gridDim.z}, iblockIdx.x211{gridDim.x}, ithreadIdx.x208{blockDim.x}_p, iUS210{1}, iV206{8} ] ca_pos( 8 )

Note that running with the fallback allowed, or in forward-only mode, is at best no faster than standard (eager) PyTorch and at worst slower, so in practice it is better to run without torch.jit and nvFuser altogether.
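
(For anyone who wants to reproduce that comparison: nvFuser can be toggled off while keeping TorchScript via the private switch below. This is an internal API in 1.12/1.13 and may change, so treat it as a rough sketch.)

import torch

# Private toggle in PyTorch 1.12/1.13; returns the previous setting.
# With nvFuser disabled, the scripted module runs without the fuser,
# which makes it easy to compare against the fused path.
was_enabled = torch._C._jit_set_nvfuser_enabled(False)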

This is running on Windows 11 + WSL2 (Ubuntu 20.04 LTS), with an AMD Threadripper CPU and an NVIDIA RTX A6000 GPU. Automatic mixed precision is enabled in these runs, but from memory it does not make a difference either way.

I would be grateful for any insights anyone may have.

Would it be possible to get a proxy model which reproduces the same error in the latest PyTorch release?

Thanks @ptrblck, I’ll see what can be done. Hopefully the issue can be isolated, although I fear it will require the whole model.

(We have also shared code privately with vendors in the past, if that was an option.)

@ptrblck, I have been able to isolate the problem to the following repro:

import torch
import torch.jit
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, mlp):
        super().__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel, *rest = mlp
        for out_channel in rest:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, points):
        new_points = points
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
            #            ↑ fails with relu, leaky_relu, gelu, silu
            #              passes with hardtanh, relu6, elu, selu, celu, rrelu, glu, logsigmoid
        return new_points


if __name__ == '__main__':
    device = torch.device('cuda:0')

    new_points = torch.randn((2, 16, 1, 1024), dtype=torch.float32, device=device)
    mlp = torch.jit.script(MLP([16, 32]).to(device))
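    # Call the scripted module several times: the profiling executor only
    # hands the graph to nvFuser after a few profiling runs, so the failure
    # shows up on a later call rather than the first.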
    out1 = mlp(new_points)
    out2 = mlp(new_points)
    out3 = mlp(new_points)
    out4 = mlp(new_points)

As noted in the code comment, the failure seems tied to the choice of activation function: relu, leaky_relu, gelu, and silu trigger it, while hardtanh, relu6, elu, selu, celu, rrelu, glu, and logsigmoid do not.
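
For anyone who wants to check the activations themselves, a minimal sweep is sketched below. ConvBNAct is an illustrative stand-in for one MLP layer, not our real code; run it with PYTORCH_NVFUSER_DISABLE=fallback so a codegen failure surfaces as a RuntimeError rather than the silent fallback.

import torch
import torch.nn as nn


class ConvBNAct(nn.Module):
    # Minimal stand-in for one MLP layer with a swappable activation.
    def __init__(self, act: nn.Module):
        super().__init__()
        self.conv = nn.Conv2d(16, 32, 1, bias=False)
        self.bn = nn.BatchNorm2d(32)
        self.act = act

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


if __name__ == '__main__':
    device = torch.device('cuda:0')
    x = torch.randn((2, 16, 1, 1024), device=device)
    for name, act in [('relu', nn.ReLU()), ('gelu', nn.GELU()),
                      ('silu', nn.SiLU()), ('hardtanh', nn.Hardtanh()),
                      ('elu', nn.ELU()), ('relu6', nn.ReLU6())]:
        m = torch.jit.script(ConvBNAct(act).to(device))
        try:
            for _ in range(4):  # a few calls so the profiling executor fuses
                m(x)
            print(f'{name}: ok')
        except RuntimeError as err:
            print(f'{name}: failed ({str(err).splitlines()[-1]})')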

Should I raise a bug on PyTorch’s GitHub?

Thanks for the code snippet!
To understand the issue correctly: in 1.13.0 you are seeing the UserWarning by default, and the RuntimeError if PYTORCH_NVFUSER_DISABLE=fallback is set?
I’ll forward the issue to the nvFuser team so that they can take a look at it.

Correct: the UserWarning appears by default, and the RuntimeError if the fallback is disabled via the environment variable.
Thank you very much, much appreciated!

Quick update: the issue is reproducible on our side using the latest nightly, but not in the nvFuser development branch. Once the fix is merged upstream, I’ll ping you here so that you can verify it.
