JIT with torch.nn.functional.interpolate

Hi,
I found some strange behavior (maybe it’s normal, I don’t know) while converting one of my models with the JIT.

When I use the JIT to export my model with torch.jit.trace(model, torch.randn(1, 2, 10, 10, 10)), and I have a torch.nn.functional.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True) inside the forward pass, the traced model seems to work correctly only with inputs of size (1, 2, 10, 10, 10).

Here is a small script to reproduce the behavior:

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
            new_x, scale_factor=2, mode="trilinear", align_corners=True)
        return up_x


inp_5 = torch.randn(1, 5, 5, 5, 5)
inp_10 = torch.randn(1, 5, 10, 10, 10)
inp_15 = torch.randn(1, 5, 15, 15, 15)

model = Model()
model.eval()
trace = torch.jit.trace(model, inp_10)
trace.save("trace.pth")

result_model_5 = model(inp_5)
result_model_10 = model(inp_10)
result_model_15 = model(inp_15)

t_model = torch.jit.load("trace.pth")
result_t_model_5 = t_model(inp_5)
result_t_model_10 = t_model(inp_10)
result_t_model_15 = t_model(inp_15)

print("Shape  5, {} ||| {}".format(result_model_5.shape, result_t_model_5.shape))
print("Shape 10, {} ||| {}".format(result_model_10.shape, result_t_model_10.shape))
print("Shape 15, {} ||| {}".format(result_model_15.shape, result_t_model_15.shape))
print(torch.allclose(result_model_5, result_t_model_5))
print(torch.allclose(result_model_10, result_t_model_10))
print(torch.allclose(result_model_15, result_t_model_15))

Outputs:

Shape  5, torch.Size([1, 1, 10, 10, 10]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 10, torch.Size([1, 1, 20, 20, 20]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 15, torch.Size([1, 1, 30, 30, 30]) ||| torch.Size([1, 1, 20, 20, 20])
Traceback (most recent call last):
  File "main.py", line 37, in <module>
    print(torch.allclose(result_model_5, result_t_model_5))
RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 4

Is this normal behavior?

Seems odd.

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
            new_x, scale_factor=2, mode="trilinear", align_corners=True)
        return up_x


inp_5 = torch.randn(1, 5, 5, 5, 5)
inp_10 = torch.randn(1, 5, 10, 10, 10)
inp_15 = torch.randn(1, 5, 15, 15, 15)

model = Model()
model.eval()
trace = torch.jit.trace(model, inp_10)
trace.save("trace.pth")

result_model_5 = model(inp_5)
result_model_10 = model(inp_10)
result_model_15 = model(inp_15)
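# Baseline: both columns in the next three prints are the eager model's shapes.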
print("Shape  5, {} ||| {}".format(result_model_5.shape, result_model_5.shape))
print("Shape 10, {} ||| {}".format(result_model_10.shape, result_model_10.shape))
print("Shape 15, {} ||| {}".format(result_model_15.shape, result_model_15.shape))
t_model = torch.jit.load("trace.pth")
result_t_model_5 = t_model(inp_5)
result_t_model_10 = t_model(inp_10)
result_t_model_15 = t_model(inp_15)

print("Shape  5, {} ||| {}".format(result_model_5.shape, result_t_model_5.shape))
print("Shape 10, {} ||| {}".format(result_model_10.shape, result_t_model_10.shape))
print("Shape 15, {} ||| {}".format(result_model_15.shape, result_t_model_15.shape))
print(torch.allclose(result_model_5, result_t_model_5))
print(torch.allclose(result_model_10, result_t_model_10))
print(torch.allclose(result_model_15, result_t_model_15))

Execute this and revert, please.

I don’t see any difference with your proposal:

Shape  5, torch.Size([1, 1, 10, 10, 10]) ||| torch.Size([1, 1, 10, 10, 10])
Shape 10, torch.Size([1, 1, 20, 20, 20]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 15, torch.Size([1, 1, 30, 30, 30]) ||| torch.Size([1, 1, 30, 30, 30])

Shape  5, torch.Size([1, 1, 10, 10, 10]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 10, torch.Size([1, 1, 20, 20, 20]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 15, torch.Size([1, 1, 30, 30, 30]) ||| torch.Size([1, 1, 20, 20, 20])
Traceback (most recent call last):
  File "test_ans.py", line 39, in <module>
    print(torch.allclose(result_model_5, result_t_model_5))
RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 4

There still seems to be an issue with the model generated by the JIT. When you say “revert”, do you mean retest?

Tracing doesn’t understand dynamic control flow, so sometimes it will “constant-ify” shapes in your model. Try turning your model into a ScriptModule and using TorchScript; that should fix the problem.
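You can see the baked-in sizes for yourself by inspecting the trace. A quick check (this assumes the trace object from your script; the exact printout varies across PyTorch versions):

# The traced graph shows the interpolate output size as a hard-coded
# constant rather than a value computed from the input shape.
print(trace.graph)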


Thanks for your time!

I made it work with something like this:

import torch


class Interpolate(torch.jit.ScriptModule):
    __constants__ = ["scale_factor", "mode", "align_corners"]

    def __init__(self, scale_factor=2.0, mode="nearest", align_corners=None):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    @torch.jit.script_method
    def forward(self, X):
        # Scripted, so scale_factor stays a real multiplier instead of an
        # output size baked in at trace time.
        return torch.nn.functional.interpolate(
            X, scale_factor=self.scale_factor,
            mode=self.mode, align_corners=self.align_corners)
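
Plugged into the original model it looks like this (just a sketch, mirroring the interpolate arguments from my first script):

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)
        # The scripted submodule keeps its dynamic shape logic even when
        # the surrounding Model is traced.
        self.up = Interpolate(scale_factor=2.0, mode="trilinear",
                              align_corners=True)

    def forward(self, x):
        return self.up(self.conv(x))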

I finally get the desired output:

Shape  5, torch.Size([1, 1, 10, 10, 10]) ||| torch.Size([1, 1, 10, 10, 10])
Shape 10, torch.Size([1, 1, 20, 20, 20]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 15, torch.Size([1, 1, 30, 30, 30]) ||| torch.Size([1, 1, 30, 30, 30])
True
True
True