TorchScript: indexing question / filling NaNs

Hi!

I’m trying to compile a model with torch.jit.script. The model uses the t[t != t] = ... trick to fill the NaNs of a tensor with a default value. However, TorchScript does not seem to accept this kind of indexing. Does anyone know a solution or workaround?
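For readers unfamiliar with the trick: under IEEE 754 floating point, NaN is the only value that compares unequal to itself, so t != t is a boolean mask that is True exactly at the NaN positions (torch.isnan(t) produces the same mask). Below is a minimal eager-mode sketch with made-up example values. Outside of TorchScript, assigning a plain Python scalar through the mask works.

```python
import torch

# Example tensor with two NaN entries (values chosen for illustration).
t = torch.tensor([1.0, float("nan"), 3.0, float("nan")])

# NaN != NaN, so this mask is True only where t holds a NaN.
mask = t != t
assert torch.equal(mask, torch.isnan(t))

# In eager mode, a Python int on the right-hand side is accepted.
t[mask] = 7
print(mask)
print(t)
```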

Thank you!

Here’s a small example plus error message:

@torch.jit.script
def f():
    t = torch.tensor(0.)
    t[t!=t] = 7

RuntimeError: 
Arguments for call are not valid.
The following operator variants are available:
  
  aten::index_put_(Tensor(a) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> (Tensor(a)):
  Expected a value of type 'Tensor' for argument 'values' but instead found type 'int'.
  
  aten::index_put_(Tensor(a) self, Tensor[] indices, Tensor values, bool accumulate=False) -> (Tensor(a)):
  Expected a value of type 'List[Tensor]' for argument 'indices' but instead found type 'List[Optional[Tensor]]'.

The original call is:
at <ipython-input-21-45d08c8a21f7>:4:5
@torch.jit.script
def f():
    t = torch.tensor(0.)
    t[t!=t] = 7
    ~~~~~~~~~~~ <--- HERE

I’m not sure whether it’s the indexing or rather the right-hand side of the assignment, since the error says:

Expected a value of type 'Tensor' for argument 'values' but instead found type 'int'.

This seems to work:

import torch

@torch.jit.script
def f():
    t = torch.tensor([0.])
    # The right-hand side must be a Tensor, not a Python int
    t[t!=t] = torch.tensor(7.)
    return t

f()
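Two further variants that should also script, sketched under the assumption of a reasonably recent PyTorch (the function names fill_nan_inplace and fill_nan are hypothetical). The first keeps the in-place masked assignment but builds the right-hand side as a 0-dim tensor with t's dtype and device, since index_put_ may reject a value whose dtype differs from the destination (for example a float32 torch.tensor(7.) written into a float64 tensor). The second avoids indexed assignment entirely and uses the out-of-place masked_fill, which accepts a scalar fill value directly.

```python
import torch

@torch.jit.script
def fill_nan_inplace(t: torch.Tensor, value: float) -> torch.Tensor:
    # index_put_ expects a Tensor for 'values', so wrap the scalar in a
    # 0-dim tensor that matches t's dtype and device.
    fill = torch.tensor(value, dtype=t.dtype, device=t.device)
    t[torch.isnan(t)] = fill  # mutates t in place
    return t

@torch.jit.script
def fill_nan(t: torch.Tensor, value: float) -> torch.Tensor:
    # masked_fill takes a scalar and returns a new tensor; t is unchanged.
    return t.masked_fill(torch.isnan(t), value)

x = torch.tensor([1.0, float("nan"), 3.0])
print(fill_nan(x, 7.0))
print(fill_nan_inplace(x.clone(), 7.0))
```

The masked_fill version is usually the simpler choice when the input does not need to be mutated; the in-place version mirrors the original t[t != t] = ... pattern more closely.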

Your interpretation of the error message makes a lot more sense.
Thank you!