Hi everyone,
I’ve come across some interesting behavior regarding TorchScript and index ops. The following two functions are meant to do the same job: clamp near-zero standard deviations to 1. However, inspecting the generated TorchScript code shows that in the latter the JIT compiler completely removes the masked in-place fill: data_sigma[data_sigma < 1e-12].fill_(1.0)
In [13]: @torch.jit.script
    ...: def fit(data: Tensor):
    ...:     # Reduce all but the last dimension
    ...:     # pylint:disable=unnecessary-comprehension
    ...:     reduction_dims = [i for i in range(data.dim() - 1)]
    ...:     # pylint:enable=unnecessary-comprehension
    ...:
    ...:     data_mu = torch.mean(data, dim=reduction_dims, keepdim=True)
    ...:     data_sigma = torch.std(data, dim=reduction_dims, keepdim=True)
    ...:     data_sigma = torch.where(data_sigma < 1e-12, torch.ones_like(data_sigma), data_sigma)
    ...:     return data_mu, data_sigma
    ...:
In [14]: print(fit.code)
def fit(data: Tensor) -> Tuple[Tensor, Tensor]:
    reduction_dims = annotate(List[int], [])
    for i in range(torch.sub(torch.dim(data), 1)):
        _0 = torch.append(reduction_dims, i)
    data_mu = torch.mean(data, reduction_dims, True, dtype=None)
    data_sigma = torch.std(data, reduction_dims, True, True)
    _1 = torch.lt(data_sigma, 9.9999999999999998e-13)
    _2 = torch.ones_like(data_sigma, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None)
    data_sigma0 = torch.where(_1, _2, data_sigma)
    return (data_mu, data_sigma0)
In [15]: @torch.jit.script
    ...: def fit(data: Tensor):
    ...:     # Reduce all but the last dimension
    ...:     # pylint:disable=unnecessary-comprehension
    ...:     reduction_dims = [i for i in range(data.dim() - 1)]
    ...:     # pylint:enable=unnecessary-comprehension
    ...:
    ...:     data_mu = torch.mean(data, dim=reduction_dims, keepdim=True)
    ...:     data_sigma = torch.std(data, dim=reduction_dims, keepdim=True)
    ...:     # data_sigma = torch.where(data_sigma < 1e-12, torch.ones_like(data_sigma), data_sigma)
    ...:     data_sigma[data_sigma < 1e-12].fill_(1.0)
    ...:     return data_mu, data_sigma
    ...:
In [16]: print(fit.code)
def fit(data: Tensor) -> Tuple[Tensor, Tensor]:
    reduction_dims = annotate(List[int], [])
    for i in range(torch.sub(torch.dim(data), 1)):
        _0 = torch.append(reduction_dims, i)
    data_mu = torch.mean(data, reduction_dims, True, dtype=None)
    data_sigma = torch.std(data, reduction_dims, True, True)
    return (data_mu, data_sigma)
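My current suspicion is that the fill_ never modifies data_sigma in the first place: indexing with a boolean mask returns a copy, so fill_ mutates a temporary that is immediately discarded, and the JIT then eliminates the whole chain as dead code. A quick eager-mode check seems consistent with that (a minimal sketch outside TorchScript; the example tensor is mine):

import torch

t = torch.tensor([1e-13, 0.5, 2.0])
t[t < 1e-12].fill_(1.0)  # boolean indexing returns a copy; fill_ mutates only that temporary
print(t)                 # tensor([1.0000e-13, 5.0000e-01, 2.0000e+00]) -- unchanged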
Does anyone have a good explanation for this behavior? I’m now worried that the same thing may be happening silently in other parts of my code; such masked assignments are important for avoiding numerical-precision errors.
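For now I’m sticking with the torch.where version. An in-place alternative that should write through to data_sigma itself is masked_fill_ (or the equivalent indexed assignment); a minimal sketch below, which I haven’t yet verified under the JIT:

import torch

data_sigma = torch.tensor([1e-13, 0.5, 2.0])

# masked_fill_ mutates data_sigma itself rather than a temporary copy
data_sigma.masked_fill_(data_sigma < 1e-12, 1.0)

# Indexed assignment should behave the same way:
# data_sigma[data_sigma < 1e-12] = 1.0

print(data_sigma)  # tensor([1.0000, 0.5000, 2.0000])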
Cheers,
Ângelo