Making JIT work

Hello, I have the following code:

import torch


def normalize(data: torch.Tensor, mean: torch.Tensor,
              std: torch.Tensor) -> torch.Tensor:

    """Normalise the image with channel-wise mean and standard deviation.

    Args:
        data (torch.Tensor): The image tensor to be normalised.
        mean (torch.Tensor): Mean for each channel.
        std (torch.Tensor): Standard deviations for each channel.

    Returns:
        Tensor: The normalised image tensor.
    """

    if not torch.is_tensor(data):
        raise TypeError('data should be a tensor. Got {}'.format(type(data)))

    if not torch.is_tensor(mean):
        raise TypeError('mean should be a tensor. Got {}'.format(type(mean)))

    if not torch.is_tensor(std):
        raise TypeError('std should be a tensor. Got {}'.format(type(std)))

    if len(mean) != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
        raise ValueError('mean length and number of channels do not match')

    if len(std) != data.shape[-3] and std.shape[:2] != data.shape[:2]:
        raise ValueError('std length and number of channels do not match')

    if std.shape != mean.shape:
        raise ValueError('std and mean must have the same shape')

    mean = mean[..., :, None, None].to(data.device)
    std = std[..., :, None, None].to(data.device)

    out = data.sub(mean).div(std)

    return out

I would like this function to work both with and without JIT. The problem I am currently facing is that, since this function has control-flow statements (if statements), when I run:

        f = image.Normalize(mean, std)
        jit_trace = torch.jit.trace(f, data)
        jit_trace(data2)

It runs correctly, but if the input changes so that the control flow should take a different branch, the traced function neither raises the expected exception nor returns the expected output.
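A minimal sketch of why this happens, using a made-up function double_or_negate rather than normalize: torch.jit.trace runs the function once on the example input and records only the operations that actually executed, so the if is evaluated once and its outcome is baked into the graph (PyTorch typically prints a TracerWarning when a tensor is converted to a Python bool during tracing).

```python
import torch

def double_or_negate(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return -x

# Tracing with a positive input records only the `x * 2` branch.
traced = torch.jit.trace(double_or_negate, torch.ones(3))

neg = -torch.ones(3)
print(double_or_negate(neg))  # eager: takes the `-x` branch -> tensor([1., 1., 1.])
print(traced(neg))            # traced: still multiplies by 2 -> tensor([-2., -2., -2.])
```

The same thing happens to the shape checks in normalize: whichever way they evaluated on the tracing input is what every later call gets.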

How can I make this work? Thanks in advance!

If you want to use data-dependent control flow in TorchScript, you need to use script mode; see the docs here, under “Scripting”, for details and examples.

For your model in particular, it looks like it would work to add @torch.jit.script to normalize() and remove the call to torch.jit.trace.
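For comparison, a sketch of a hypothetical function compiled with @torch.jit.script: the scripted version keeps the if (and the raise) in the compiled code, so each call takes the branch that matches its own input. The function name and the dimension check are illustrative, not part of the original normalize.

```python
import torch

@torch.jit.script
def double_or_negate(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 0:
        # In script mode the raise is compiled in and fires at call time.
        raise ValueError("expected at least one dimension")
    # The branch is re-evaluated on every call, unlike a traced graph.
    if bool(x.sum() > 0):
        return x * 2
    return -x

print(double_or_negate(torch.ones(3)))   # tensor([2., 2., 2.])
print(double_or_negate(-torch.ones(3)))  # tensor([1., 1., 1.])
```

Calling it on a 0-d tensor such as torch.tensor(1.0) raises an error from inside the scripted code instead of silently returning a result.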

Hi, if I do that, is there a way to call normalize() without JIT? I would like to have the choice of running it with and without JIT, in case I want to add a non-jittable operation in the future.

There is a global environment variable, PYTORCH_JIT, that you can set to 0 to switch off the JIT entirely, so that the scripted parts run as regular eager PyTorch.

https://pytorch.org/docs/stable/jit.html?highlight=pytorch_jit#envvar-PYTORCH_JIT=1
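Besides the global switch, another way to keep both versions around is to call torch.jit.script as a function instead of using it as a decorator. The undecorated Python function stays available for eager use, and the scripted copy is a separate object. This is a sketch; the names scale_eager and scale_scripted are illustrative.

```python
import torch

def scale_eager(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Plain eager PyTorch; this function itself is never modified.
    if factor == 0.0:
        raise ValueError("factor must be non-zero")
    return x / factor

# Compile a separate TorchScript version alongside the eager one.
scale_scripted = torch.jit.script(scale_eager)

x = torch.arange(4.0)
print(torch.equal(scale_eager(x, 2.0), scale_scripted(x, 2.0)))  # True
```

If an operation TorchScript cannot compile is added later, only the torch.jit.script(...) line fails; the eager function keeps working.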


Thanks for the replies. However, I cannot find the environment variable PYTORCH_JIT on my system. To search for environment variables, I use the following code:

import os

for key in os.environ.keys():
    if "py" in key.lower():
        print(key)

Which yields:
CONDA_PYTHON_EXE
PYTHONPATH

Nothing PyTorch-related is shown.

On a second note, with @torch.jit.script the scripted code runs and throws the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/diego/Projects/torchgeometry/torchgeometry/image/normalization.py", line 28, in forward
    return normalize(input, self.mean, self.std)
RuntimeError:
Dimension out of range (expected to be in range of [-1, 0], but got -3) (maybe_wrap_dim at /opt/conda/conda-bld/pytorch_1556653215914/work/c10/core/WrapDimMinimal.h:20)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7ff3515d7dc5 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x7d4a40 (0x7ff351fc5a40 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #2: at::native::slice(at::Tensor const&, long, long, long, long) + 0x4e (0x7ff351fc638e in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #3: at::TypeDefault::slice(at::Tensor const&, long, long, long, long) const + 0x1a (0x7ff3522255fa in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #4: torch::autograd::VariableType::slice(at::Tensor const&, long, long, long, long) const + 0x6d3 (0x7ff34a32ea23 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #5: <unknown function> + 0x985117 (0x7ff34a4ec117 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: <unknown function> + 0xa73df8 (0x7ff34a5dadf8 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #7: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x22 (0x7ff34a5d6372 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #8: <unknown function> + 0xa5b2d9 (0x7ff34a5c22d9 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #9: <unknown function> + 0x458bf3 (0x7ff3776d0bf3 in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x12d07a (0x7ff3773a507a in /home/diego/Projects/torchgeometry/.dev_env/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #37: __libc_start_main + 0xe7 (0x7ff384563b97 in /lib/x86_64-linux-gnu/libc.so.6)
:
operation failed in interpreter:
        raise ValueError('mean lenght and number of channels do not match')

    if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
        raise ValueError('std lenght and number of channels do not match')

    if std.shape != mean.shape:
        raise ValueError('std and mean must have the same shape')

    '''
    mean = mean[..., :, None, None].to(data.device)
           ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    std = std[..., :, None, None].to(data.device)

    out = (data - mean) / std

    return out

What am I doing wrong?

I am using PyTorch 1.1 in an Anaconda environment.

You have to set it yourself, so if your code is in model.py, running

$ PYTORCH_JIT=0 python model.py

would set the variable (assuming you’re using bash).
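If you would rather not change the shell command, here is a sketch of setting the variable from Python itself. It assumes the assignment runs before anything in the process imports torch, because PyTorch reads PYTORCH_JIT once, at import time. With the JIT disabled, @torch.jit.script returns the original Python function unchanged. The function add_one is made up for illustration.

```python
import os

# Must run before the first `import torch` in this process.
os.environ["PYTORCH_JIT"] = "0"

import torch  # noqa: E402

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# With the JIT disabled, the decorator hands back the plain Python function,
# so no TorchScript graph is attached to it.
print(hasattr(add_one, "graph"))  # False
print(add_one(torch.zeros(2)))     # tensor([1., 1.])
```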

As for your scripted code, I think the problem may have something to do with your input shapes; the snippet below works for me. (TorchScript is statically typed, so the is_tensor checks would always evaluate to true; that’s why I removed them.) Its use of data.shape[-3] implies that the inputs need to have at least 4 dimensions.

import torch


@torch.jit.script
def normalize(data: torch.Tensor, mean: torch.Tensor,
              std: torch.Tensor) -> torch.Tensor:
    if len(mean) != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
        raise ValueError('mean length and number of channels do not match')

    if len(std) != data.shape[-3] and std.shape[:2] != data.shape[:2]:
        raise ValueError('std length and number of channels do not match')

    if std.shape != mean.shape:
        raise ValueError('std and mean must have the same shape')

    mean = mean[..., :, None, None].to(data.device)
    std = std[..., :, None, None].to(data.device)

    out = data.sub(mean).div(std)

    return out

print(normalize(torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2)))

I tried executing the script with $PYTORCH_JIT=0 as you said and got the following error:

As for the RuntimeError, I don’t think there is anything wrong with the code, since it runs in native PyTorch. I also had to comment out the is_tensor lines, since they raise a different exception.
I don’t know why they are not commented out in the snippet I pasted, but I did comment out those lines, so even if data.shape[-3] is wrong, those lines should not run at all.

P.S.: -3 requires at least 3 dimensions, since the last one is -1.
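A quick way to check the P.S. and to read the original error message: the range in “expected to be in range of [-1, 0], but got -3” is what PyTorch reports when dimension -3 is requested from a tensor with only one dimension, while a 3-D tensor accepts -3. In eager mode, the [..., :, None, None] indexing used in normalize() works fine on a 1-D tensor. This only illustrates what the message means; it does not by itself show which tensor the scripted code was slicing.

```python
import torch

one_d = torch.tensor([0.5])    # e.g. a per-channel mean for a single channel
three_d = torch.ones(2, 4, 4)  # (C, H, W)

print(three_d.size(-3))  # 2: -3 counts back from the last dimension (-1)

try:
    one_d.size(-3)
except (IndexError, RuntimeError) as err:
    # Recent releases raise IndexError; the 1.1 traceback in this thread shows RuntimeError.
    message = str(err)
    print(message)

# Eager-mode indexing from normalize() on a 1-D tensor:
print(one_d[..., :, None, None].shape)  # torch.Size([1, 1, 1])
```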

It looks like we had a bug with PYTORCH_JIT=0; it’s fixed in #20120.

Can you post the full code you are using to run this that generates the error?

Sure thing. It’s just a simple test:

    def test_normalize(self):

        # prepare input data
        data = torch.ones(1, 2, 2)
        mean = torch.tensor([0.5])
        std = torch.tensor([2.0])

        # expected output
        expected = torch.tensor([0.25]).repeat(1, 2, 2).view_as(data)

        f = image.Normalize(mean, std)
        assert_allclose(f(data), expected)

And this is the Normalize class definition:

class Normalize(nn.Module):

    """
    Normalize a tensor image or a batch of tensor images
    with mean and standard deviation. Input must be a tensor of shape (C, H, W)
    or a batch of tensors (*, C, H, W).
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input ``torch.*Tensor``
    i.e. ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (torch.Tensor): Mean for each channel.
        std (torch.Tensor): Standard deviation for each channel.
    """

    def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:

        super(Normalize, self).__init__()

        self.mean = mean
        self.std = std

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        return normalize(input, self.mean, self.std)

    def __repr__(self):
        repr = '(mean={0}, std={1})'.format(self.mean, self.std)
        return self.__class__.__name__ + repr

P.S.: This test and every other test pass without JIT.
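One workaround to try, offered as a sketch not verified against PyTorch 1.1 rather than a confirmed fix: express the broadcasting with unsqueeze instead of [..., :, None, None], so the scripted code never relies on Ellipsis indexing. The name normalize_sketch and the accepted shapes (data of shape (C, H, W) or (B, C, H, W); mean/std of shape (C,) or (B, C)) are assumptions for illustration, and the is_tensor checks are left out because the type annotations already cover them in script mode.

```python
import torch

@torch.jit.script
def normalize_sketch(data: torch.Tensor, mean: torch.Tensor,
                     std: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: data (C, H, W) or (B, C, H, W); mean/std (C,) or (B, C).
    if data.dim() < 3:
        raise ValueError("data must have at least 3 dimensions (C, H, W)")
    if std.shape != mean.shape:
        raise ValueError("std and mean must have the same shape")
    if mean.dim() == 0 or mean.shape[-1] != data.shape[-3]:
        raise ValueError("mean length and number of channels do not match")

    # Append singleton H and W dims: (C,) -> (C, 1, 1), (B, C) -> (B, C, 1, 1).
    mean = mean.unsqueeze(-1).unsqueeze(-1).to(data.device)
    std = std.unsqueeze(-1).unsqueeze(-1).to(data.device)
    return (data - mean) / std

# Same inputs as test_normalize: (1 - 0.5) / 2 = 0.25 everywhere.
data = torch.ones(1, 2, 2)
out = normalize_sketch(data, torch.tensor([0.5]), torch.tensor([2.0]))
print(out)
```

Because normalize_sketch is an ordinary scripted function, it can be dropped into Normalize.forward in place of normalize, and PYTORCH_JIT=0 still turns it back into eager code.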