PyTorch 1.13: RuntimeError: tensors used as indices must be long, byte or bool tensors

I have the following code that works on PyTorch 1.11 but fails on 1.13:

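import torch

# Load the TorchScript archive and switch it to inference mode
model = torch.jit.load("foo/")
model = model.eval()
# Load the archive a second time and freeze all of its parameters
x = torch.jit.load("foo/")
inputs = list(x.parameters())
for i in inputs:
    i.requires_grad = False
print("ALL GOOD")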

It fails with the following error:

worker_exedir/cruise/mlp/prediction/pytorch/modules/utils/", line 81, in gather_from_map

    map_idx = torch.stack([batch_indices, y_indices, x_indices], dim=-1)
    map_features = map[map_idx[..., 0], map_idx[..., 1], map_idx[..., 2], :]
                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return map_features
RuntimeError: tensors used as indices must be long, byte or bool tensors

We hit the same error in eager mode and fixed it there.
Is there a way to fix it in the TorchScript model without retraining it?

Hi @Sergei_Vorobev,

What type are batch_indices, y_indices, x_indices? I assume they are torch.float32?

You should be able to fix this by adding a line that casts the indices to integers, e.g. map_idx = map_idx.long(), between the map_idx = ... and map_features = ... lines.
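
A minimal eager-mode sketch of that fix, assuming the three index tensors arrive as floats. Only three lines of gather_from_map are visible in the traceback, so the signature below is a hypothetical reconstruction, and the map argument is renamed feature_map because map shadows Python's builtin. The cast goes right after torch.stack, because stacking float tensors produces a float index tensor.

```python
import torch


def gather_from_map(feature_map: torch.Tensor,
                    batch_indices: torch.Tensor,
                    y_indices: torch.Tensor,
                    x_indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical signature reconstructed from the traceback.
    # Stack the per-axis indices into one (..., 3) index tensor.
    map_idx = torch.stack([batch_indices, y_indices, x_indices], dim=-1)
    # Advanced indexing rejects float index tensors (the 1.13 error lists
    # long, byte or bool), so cast to int64. Float values truncate toward zero.
    map_idx = map_idx.long()
    # Select feature_map[b, y, x, :] for every (b, y, x) triple.
    return feature_map[map_idx[..., 0], map_idx[..., 1], map_idx[..., 2], :]
```

If the float indices can carry rounding noise (e.g. 2.9999 for 3), round them before calling .long(), since the cast truncates rather than rounds.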

Thank you, @AlphaBetaGamma96!
I understand how to do this in the Python code (eager mode).
Is there a way to do it with an existing TorchScript model loaded from a .pt file?
I'd like to keep the weights but update the code…

Update: I patched the TorchScript model directly with @AlphaBetaGamma96's suggestion. You can do this by unzipping the .pt file, finding and editing the corresponding source file inside the archive, then zipping it back up. The serialized code is structured a little differently from the original Python, but in my case the mapping was easy to figure out. Here is a helpful resource on TorchScript serialization: pytorch/ at master · pytorch/pytorch · GitHub
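
A rough sketch of that unzip/edit/re-zip workflow using only Python's zipfile module. The helper names find_code_members and patch_torchscript_archive are hypothetical. The code assumes serialized sources live as .py files under a code/ directory inside the archive (e.g. <archive_name>/code/__torch__/...). The old string must be copied from the archive itself, because the serialized code is printed differently from the original Python (for example, keyword arguments may appear as positional ones). The sketch reads every entry, weights included, into memory. Any .debug_pkl files next to the sources are left untouched, and how PyTorch reconciles them with edited code is not verified here. After patching, load the result with torch.jit.load and run a known input to confirm it works.

```python
import zipfile


def find_code_members(archive, needle):
    """Return the names of serialized source files (under code/) that contain needle."""
    hits = []
    with zipfile.ZipFile(archive) as zf:
        for name in zf.namelist():
            parts = name.split("/")
            # Assumption: TorchScript sources are .py files somewhere under a "code" directory.
            if "code" in parts[:-1] and name.endswith(".py"):
                if needle in zf.read(name).decode("utf-8"):
                    hits.append(name)
    return hits


def patch_torchscript_archive(src, dst, member, old, new):
    """Copy the .pt zip archive src to dst, replacing the first `old` with `new` in `member`."""
    patched = False
    with zipfile.ZipFile(src) as zin, zipfile.ZipFile(dst, "w") as zout:
        for info in zin.infolist():
            data = zin.read(info.filename)
            if info.filename == member:
                text = data.decode("utf-8")
                if old not in text:
                    raise ValueError(f"{old!r} not found in {member}")
                data = text.replace(old, new, 1).encode("utf-8")
                patched = True
            # Keep the entry name, timestamp and compression method; sizes and CRC are recomputed.
            out_info = zipfile.ZipInfo(info.filename, date_time=info.date_time)
            out_info.compress_type = info.compress_type
            zout.writestr(out_info, data)
    if not patched:
        raise ValueError(f"member {member!r} not found in archive")
```

A typical flow would be to call find_code_members("model.pt", "torch.stack") to locate the file, open it to copy the exact map_idx = ... line, and then call patch_torchscript_archive with new set to that line followed by the map_idx = map_idx.long() cast (matching the file's indentation). Writing to a separate output file keeps the original archive intact if anything goes wrong.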
