Hi,
I have encountered some unexpected behavior when mixing `torch.jit.script` and `torch.jit.trace`. Here's an example to reproduce it:
```python
import torch
import numpy as np

@torch.jit.script
def select_rows(nums: int, data: torch.Tensor, size: int):
    # Randomly pick `size` row indices out of `nums`, without replacement.
    valid_choice = torch.multinomial(torch.ones(nums).float(), size)
    return data[valid_choice]

def do_selection(x):
    return select_rows(x.shape[0], x, x.shape[0])

t_4 = torch.tensor(np.array([1, 2, 3, 4]))
t_7 = torch.tensor(np.array([1, 2, 3, 4, 5, 6, 7]))

traced_selection = torch.jit.trace(do_selection, t_4)

print(traced_selection(t_4))
# tensor([3, 1, 2, 4])  <- a random arrangement of the input data.
print(traced_selection(t_7))
# tensor([1, 3, 2, 4])  <- another random arrangement, but of only 4 elements, as in the TRACED EXAMPLE!
#                          Expected a random arrangement of the current input of size 7!
```
In my actual example, `do_selection()` is extremely complicated and cannot be scripted with `torch.jit.script`. What are my options here? Is this the expected behavior?
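One possible direction, sketched under the assumption that the size-dependent sampling can be moved entirely inside the scripted function: instead of passing `x.shape[0]` from traced code as a Python `int` (which the tracer may freeze to the example's value), the scripted function reads the size from the tensor itself at run time. `select_rows_dynamic` is a hypothetical name, and this has not been checked against the real, more complicated `do_selection`.

```python
import torch

# Hypothetical rework of select_rows: the size is read from `data` inside the
# scripted body, so it is computed at run time instead of being passed in as
# an int from the traced caller.
@torch.jit.script
def select_rows_dynamic(data: torch.Tensor) -> torch.Tensor:
    n = data.size(0)
    # Drawing n of n categories without replacement gives a random permutation.
    valid_choice = torch.multinomial(torch.ones([n]), n)
    return data[valid_choice]

def do_selection(x):
    # Only the tensor crosses the traced/scripted boundary.
    return select_rows_dynamic(x)

t_4 = torch.tensor([1, 2, 3, 4])
t_7 = torch.tensor([1, 2, 3, 4, 5, 6, 7])

# check_trace=False skips re-running and comparing the trace, which would
# otherwise warn because multinomial is nondeterministic.
traced_selection = torch.jit.trace(do_selection, t_4, check_trace=False)

print(traced_selection(t_4))  # a permutation of the 4 input elements
print(traced_selection(t_7))  # expected: a permutation of all 7 input elements
```

The design idea is simply to keep shape-dependent logic on the scripted side of the boundary, so the traced graph only forwards the tensor.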
Thanks.