Hello, I have the following code:
```python
import torch


def normalize(data: torch.Tensor, mean: torch.Tensor,
              std: torch.Tensor) -> torch.Tensor:
    """Normalise the image with channel-wise mean and standard deviation.

    Args:
        data (torch.Tensor): The image tensor to be normalised.
        mean (torch.Tensor): Mean for each channel.
        std (torch.Tensor): Standard deviation for each channel.

    Returns:
        torch.Tensor: The normalised image tensor.
    """
    if not torch.is_tensor(data):
        raise TypeError('data should be a tensor. Got {}'.format(type(data)))
    if not torch.is_tensor(mean):
        raise TypeError('mean should be a tensor. Got {}'.format(type(mean)))
    if not torch.is_tensor(std):
        raise TypeError('std should be a tensor. Got {}'.format(type(std)))
    # Raise only if mean/std match neither a per-channel (C,) layout
    # nor a per-sample (B, C) layout.
    if len(mean) != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
        raise ValueError('mean length and number of channels do not match')
    if len(std) != data.shape[-3] and std.shape[:2] != data.shape[:2]:
        raise ValueError('std length and number of channels do not match')
    if std.shape != mean.shape:
        raise ValueError('std and mean must have the same shape')
    # Add trailing H and W axes so mean/std broadcast over the spatial dims.
    mean = mean[..., :, None, None].to(data.device)
    std = std[..., :, None, None].to(data.device)
    out = data.sub(mean).div(std)
    return out
```
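For reference, the `mean[..., :, None, None]` indexing in `normalize` turns a per-channel `(C,)` tensor into `(C, 1, 1)` and a per-sample `(B, C)` tensor into `(B, C, 1, 1)`, so `sub`/`div` broadcast over height and width. A small sketch of those shapes, assuming a `(B, C, H, W)` input; the tensor values are arbitrary:

```python
import torch

data = torch.rand(2, 3, 4, 4)             # (B, C, H, W): batch of two RGB images
mean_c = torch.tensor([0.5, 0.4, 0.3])    # (C,): one mean per channel
mean_bc = torch.rand(2, 3)                # (B, C): one mean per image and channel

# `...` keeps any leading batch dims; the two `None`s add H and W axes of size 1.
mean_c_4d = mean_c[..., :, None, None]    # shape (3, 1, 1)
mean_bc_4d = mean_bc[..., :, None, None]  # shape (2, 3, 1, 1)

# Both broadcast against (2, 3, 4, 4), so the result keeps the input's shape.
out_c = data.sub(mean_c_4d)
out_bc = data.sub(mean_bc_4d)
print(out_c.shape, out_bc.shape)  # torch.Size([2, 3, 4, 4]) torch.Size([2, 3, 4, 4])
```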
I would like this function to work both with and without JIT. The problem is that the function contains control flow (`if` statements), so when I run:
```python
f = image.Normalize(mean, std)
jit_trace = torch.jit.trace(f, data)
jit_trace(data2)
```
It runs correctly, but if the input changes in a way that should make the control flow take a different branch, the traced function neither raises the expected exception nor returns the expected output.
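A minimal reproduction of this behaviour, using a hypothetical `scale_rgb` function as a stand-in for `normalize`: `torch.jit.trace` runs the Python function once and records only the tensor operations it executes, so the `if` is evaluated at trace time and neither the check nor the `raise` ends up in the trace. `torch.jit.script`, by contrast, compiles the function's source, including control flow, so the check survives. This is a sketch assuming a reasonably recent PyTorch and a function defined in a source file (scripting reads the source); the exception type a scripted function surfaces can differ between versions, so it is caught broadly here.

```python
import warnings

import torch


def scale_rgb(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for normalize(): one shape check, then arithmetic.
    if x.shape[-3] != 3:
        raise ValueError("expected 3 channels")
    return x * 2.0


good = torch.rand(1, 3, 4, 4)
bad = torch.rand(1, 4, 4, 4)  # 4 channels: should be rejected

with warnings.catch_warnings():
    # Tracing may emit TracerWarnings about Python control flow; silence them.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(scale_rgb, good)

# The trace only recorded `x * 2.0`; the `if` ran once in Python while tracing,
# so the invalid input passes straight through without an error.
traced_out = traced(bad)
print(traced_out.shape)  # torch.Size([1, 4, 4, 4])

# Scripting compiles the function body, including the `if` and the `raise`.
scripted = torch.jit.script(scale_rgb)
try:
    scripted(bad)
    scripted_raised = False
except Exception:  # may surface as torch.jit.Error rather than ValueError
    scripted_raised = True
print(scripted_raised)  # True
```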
How can I make this work? Thanks in advance!