Torch.compile fails when indexing tensor

I have been trying to get my model to work with torch.compile but I have been having consistent failures, and I am not sure why. I am not sure if the entire model will work, but it has been failing around this section where a tensor is spliced at some ground truth index.

I rewrote the section into a smaller part here. I ran this with a fresh torch install last night with conda on the latest stable version.


import torch
import torch.nn as nn

class main_class(nn.Module):

    def __init__(self):
        super(main_class, self).__init__()

        # initialize stuff here
        self.num_labels = 26

    def forward(self, offset_pred, gt_centroid_coords):
        batch_size = 1

        batch_list = []
        x_pos = []
        y_pos = []
        z_pos = []

        for cent_bidx in range(gt_centroid_coords.shape[0]):
            for cent_channel in range(gt_centroid_coords.shape[1]):
                cent = gt_centroid_coords[cent_bidx, cent_channel]

                cent_x, cent_y, cent_z = cent.to(torch.long)

                batch_list.append(cent_bidx)
                x_pos.append(cent_x)
                y_pos.append(cent_y)
                z_pos.append(cent_z)

        offset_pred_perm = offset_pred.permute(0, 2, 3, 4, 1)

        offset_pred = offset_pred_perm[batch_list, x_pos, y_pos, z_pos].view(batch_size, self.num_labels, -1)

        return offset_pred


if __name__ == '__main__':

    saved_dict = torch.load('test_process_compile.pt', map_location='cpu')
    offset_pred = saved_dict['offset_pred']
    gt_centroid_coords = saved_dict['gt_centroid_coords']
    model = main_class().to('cuda:0')
    model = torch.compile(model)
    b = model(offset_pred, gt_centroid_coords)

Here is the error.

Traceback (most recent call last):
  File "/home/klein/.pycharm_helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "/home/klein/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/klein/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/klein/VertDetect/test_compile.py", line 55, in <module>
    b = model(offset_pred, gt_centroid_coords)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
    super().run()
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 168, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 570, in call_function
    return wrap_fx_proxy(tx, proxy, **options)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1191, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1278, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1376, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1337, in get_fake_value
    return wrap_fake_exception(
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
    return fn()
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1338, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1410, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/klein/miniconda3/envs/vert_detect_updated/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1397, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., size=(1, 192, 64, 64, 3), dtype=torch.float16,
           grad_fn=<PermuteBackward0>), ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64)], [FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64)], [FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64)], slice(None, None, None))), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
from user code:
   File "/home/klein/VertDetect/test_compile.py", line 36, in forward
    offset_pred_hold = offset_pred_perm[batch_list, x_pos, y_pos, z_pos, :]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I believe I fixed my own problem. There were two problems here. First, the inputs were not on the GPU (they are in the full model, so not a big issue). The reason, though, is that the indexing tensors are a list of tensors and not a tensor itself.

.
.
.
offset_pred_perm = offset_pred.permute(0, 2, 3, 4, 1)

x_pos = torch.tensor(x_pos, device=offset_pred.device, dtype=torch.long)
y_pos = torch.tensor(y_pos, device=offset_pred.device, dtype=torch.long)
z_pos = torch.tensor(z_pos, device=offset_pred.device, dtype=torch.long)
batch_list= torch.tensor(batch_list, device=offset_pred.device, dtype=torch.long)


offset_pred = offset_pred_perm[batch_list, x_pos, y_pos, z_pos, :].view(batch_size, self.num_labels, -1)