Unexpected behavior from TorchScript

Hi,

I have encountered some unexpected behavior when mixing torch.jit.script and torch.jit.trace. Here’s an example to reproduce it:

```python
import torch
import numpy as np

@torch.jit.script
def select_rows(
    nums: int,
    data: torch.Tensor,
    size: int
):
    valid_choice = torch.multinomial(torch.ones(nums).float(), size)
    return data[valid_choice]

def do_selection(x):
    return select_rows(x.shape[0], x, x.shape[0])

t_4 = torch.tensor(np.array([1, 2, 3, 4]))
t_7 = torch.tensor(np.array([1, 2, 3, 4, 5, 6, 7]))

traced_selection = torch.jit.trace(do_selection, t_4)

print(traced_selection(t_4))
# >>> tensor([3, 1, 2, 4])  # A random arrangement of the input data.

print(traced_selection(t_7))
# >>> tensor([1, 3, 2, 4])  # Another random arrangement, but of only 4 elements, as in the TRACED EXAMPLE!
# Expected a random arrangement of the current input of size 7!
```

In my actual code, do_selection() is extremely complicated and cannot be scripted with torch.jit.script. What are my options here? Is this the expected behavior?

Thanks.

Yes, this is expected behavior, and the tracer will also warn you:

```
TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!

Mismatched elements: 3 / 4 (75.0%)
Greatest absolute difference: 3.0 at index (3,) (up to 0.0 allowed)
Greatest relative difference: 2.0 at index (1,) (up to 1e-05 allowed)
  _check_trace(
```

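torch.jit.trace will not record any data-dependent control flow; it only records the operations executed for the example inputs you pass, so values derived from those inputs (here, x.shape[0]) can end up frozen into the trace.
You should use torch.jit.script instead and use it to narrow down exactly which parts of your code fail to compile.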
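If the complicated outer function really has to stay traced, one option that may be worth trying is the "traced code calls a scripted function" pattern from the TorchScript docs, with all shape-dependent logic moved inside the scripted part. The sketch below assumes a hypothetical rewrite, `select_rows_dynamic`, that reads the row count from the tensor itself instead of receiving `x.shape[0]` from the traced caller; it is an illustration under that assumption, not the approach used in this thread.

```python
import torch

# Hypothetical variant of select_rows: it takes only the tensor and reads the
# row count inside TorchScript, so the size is looked up at run time.
@torch.jit.script
def select_rows_dynamic(data: torch.Tensor) -> torch.Tensor:
    nums = data.shape[0]
    # Sampling nums indices without replacement from nums equal weights
    # yields a random permutation of the row indices.
    valid_choice = torch.multinomial(torch.ones([nums]), nums)
    return data[valid_choice]

def do_selection(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the complicated, non-scriptable outer function.
    # Only the tensor crosses into scripted code; no Python int from tracing.
    return select_rows_dynamic(x)

t_4 = torch.tensor([1, 2, 3, 4])
t_7 = torch.tensor([1, 2, 3, 4, 5, 6, 7])

# check_trace=False skips the trace-time output comparison, which would warn
# anyway because multinomial returns a different order on every run.
traced_selection = torch.jit.trace(do_selection, t_4, check_trace=False)

out = traced_selection(t_7)
print(out)  # all 7 elements, in a random order
```

The design choice is that the traced graph records only a call into the scripted function with the input tensor, so the `data.shape[0]` lookup happens inside TorchScript each time the trace runs rather than being fixed at trace time.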

Thanks! I was looking for an easy way out, but there wasn’t one. We were attempting to port the ViLT (ICML 2021) model to a phone.

Took a while, but it’s working now.

For anyone who’s interested: https://github.com/priyamtejaswin/ViLT/blob/script/generate_assets.py