PyTorch 1.0 Tracer Warning: Converting a tensor to a Python index might cause the trace to be incorrect

I’m converting another model with the JIT tracer and keep getting a tracer warning for the code below. I’m only passing this one function from the original model class.

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if cfg.CUDA:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

The warning:

    model.py:192: TracerWarning: Converting a tensor to a Python index might
    cause the trace to be incorrect. We can't record the data flow of Python
    values, so this value will be treated as a constant in the future. This
    means that the trace might not generalize to other inputs!
      eps = torch.FloatTensor(std.size()).normal_()

I already checked whether the argument std.size() is a Python index data type, but it’s a torch.Size:

    torch.Size([2, 100]) <class 'torch.Size'>

Meanwhile, eps is also converted to a FloatTensor. I’m confused about what is being converted into a Python index and don’t know where to start.

The JIT cannot trace through sizes (they’re just tuples of Python ints, and the tracer can strictly only follow tensors). In your case, you can use the _like functions instead (there’s a short sketch of what the warning means in practice below the snippet).
If you allow me to say so, the coding style of that function is heavily dated. Don’t use Variable or torch.FloatTensor (nor .data)!
Here would be a modern version:

    def reparametrize(mu, logvar):
        std = logvar.mul(0.5).exp_()
        # zeros_like gives eps the same shape, dtype and device as mu,
        # so no CUDA branch and no std.size() are needed
        eps = torch.normal(torch.zeros_like(mu))
        return eps.mul(std).add_(mu)
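
For context on what the warning means: when a traced function builds a new tensor from a torch.Size, the tracer cannot record where those numbers came from, so it bakes the concrete shape into the graph as a constant. A minimal illustration (my own sketch, assuming a tracer from around this PyTorch version; make_eps is just an example name):

    import torch

    def make_eps(std):
        # std.size() is a tuple of plain Python ints, so the tracer cannot follow it
        # and records the shape it saw -- this is what triggers the TracerWarning
        return torch.ones(std.size())

    traced = torch.jit.trace(make_eps, torch.zeros(2, 100))
    print(traced(torch.zeros(5, 100)).shape)  # still torch.Size([2, 100]); the trace did not generalize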

Now, it might be more natural to write eps = torch.empty_like(mu).normal_(), but the JIT hiccuped on that for me.
You’ll have to trace that with check_trace=False because of the random bits. You also need a torch version from after mid-September.

So

    # trace once with example inputs; check_trace=False because eps is random
    fn = torch.jit.trace(reparametrize, (torch.zeros(3), torch.zeros(3)), check_trace=False)
    m = torch.zeros(3, requires_grad=True)
    lv = torch.zeros(3, requires_grad=True)
    # gradients flow back through the traced function
    fn(m, lv).sum().backward()
    print(m.grad)
    print(lv.grad)
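
Since eps is now derived from mu via zeros_like, the shape travels through the traced graph instead of being baked in as a constant, so the trace should generalize to other input sizes. A quick sanity check (my own illustration, not from the original post):

    out = fn(torch.zeros(2, 100), torch.zeros(2, 100))
    print(out.shape)  # torch.Size([2, 100]), although the trace was recorded with size-3 inputs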

Best regards

Thomas
