Unexpected behavior from torchscript


I have encountered some unexpected behavior when mixing torch.jit.script and torch.jit.trace. Here’s an example to reproduce it.

import torch
import numpy as np

def select_rows(
    nums: int,
    data: torch.Tensor,
    size: int,
):
    valid_choice = torch.multinomial(torch.ones(nums).float(), size)
    return data[valid_choice]

def do_selection(x):
    return select_rows(x.shape[0], x, x.shape[0])

t_4 = torch.tensor(np.array([1, 2, 3, 4]))
t_7 = torch.tensor(np.array([1, 2, 3, 4, 5, 6, 7]))

traced_selection = torch.jit.trace(do_selection, t_4)

print(traced_selection(t_4))
>>> tensor([3, 1, 2, 4])  # A random arrangement of the input data.

print(traced_selection(t_7))
>>> tensor([1, 3, 2, 4])  # Another random arrangement, but of the TRACED EXAMPLE!
# Expected a random arrangement of the current input of size 7!

In my actual code, do_selection() is extremely complicated and cannot be scripted with torch.jit.script. What are my options here? Is this the expected behavior?


Yes, this is expected behavior and the second call will also warn you:

TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!

Mismatched elements: 3 / 4 (75.0%)
Greatest absolute difference: 3.0 at index (3,) (up to 0.0 allowed)
Greatest relative difference: 2.0 at index (1,) (up to 1e-05 allowed)

torch.jit.trace does not record any data-dependent control flow: it runs the function once with the passed example inputs and records only the tensor operations it sees, so Python values such as x.shape[0] are baked into the graph as constants (here, the size 4).
You should use torch.jit.script instead and narrow down which part of your code fails to compile with it.
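As a minimal sketch of the scripted route: compiling both functions with torch.jit.script keeps the shape-dependent sampling dynamic, so the output size follows the runtime input. (The size argument to torch.ones is wrapped in a list here, since TorchScript expects a list for factory-function size arguments; the annotations are otherwise the same as in the question.)

```python
import torch

@torch.jit.script
def select_rows(nums: int, data: torch.Tensor, size: int) -> torch.Tensor:
    # Equal weights, drawn without replacement: a random permutation of indices
    valid_choice = torch.multinomial(torch.ones([nums], dtype=torch.float), size)
    return data[valid_choice]

@torch.jit.script
def do_selection(x: torch.Tensor) -> torch.Tensor:
    # x.shape[0] is evaluated at every call, not baked in at compile time
    return select_rows(x.shape[0], x, x.shape[0])

t_7 = torch.tensor([1, 2, 3, 4, 5, 6, 7])
out = do_selection(t_7)  # a random arrangement of all 7 elements
```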
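If the outer function really cannot be scripted, one option is to mix the two: tracing preserves calls into @torch.jit.script functions, so you can script only the data-dependent core and still trace the rest. The sketch below (select_rows_dynamic is a hypothetical refactor that reads the size from the tensor itself, inside the scripted function) assumes the shape-dependent part can be isolated this way; check_trace=False is passed only because the random output would otherwise trigger a spurious mismatch warning from the tracer's self-check.

```python
import torch

@torch.jit.script
def select_rows_dynamic(data: torch.Tensor) -> torch.Tensor:
    n = data.shape[0]  # computed inside the scripted function, so not baked in
    valid_choice = torch.multinomial(torch.ones([n], dtype=torch.float), n)
    return data[valid_choice]

def do_selection(x):
    # ... arbitrarily complicated, trace-only code could live here ...
    return select_rows_dynamic(x)

t_4 = torch.tensor([1, 2, 3, 4])
t_7 = torch.tensor([1, 2, 3, 4, 5, 6, 7])

traced = torch.jit.trace(do_selection, t_4, check_trace=False)
out = traced(t_7)  # a size-7 arrangement of t_7, not of the traced example
```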