Hi,
I found a strange behavior (maybe it is expected, I don't know) while converting one of my models with the JIT.
When I export my model with `torch.jit.trace(model, torch.randn(1, 2, 10, 10, 10))` and the forward pass contains `torch.nn.functional.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True)`, the traced model seems to work only with inputs of size (1, 2, 10, 10, 10). For any other spatial size it still returns the output size computed for the tracing input.
Here is a small script to reproduce the behavior:
import torch
class Model(torch.nn.Module):
    """Minimal repro model: a 3-D convolution followed by 2x trilinear upsampling.

    The convolution maps 5 input channels to 1 output channel. ``forward``
    then doubles every spatial dimension with ``interpolate``.
    """

    def __init__(self):
        super().__init__()
        # 3x3x3 kernel with padding=1 keeps the spatial size unchanged, so the
        # only size change in forward() comes from the interpolation.
        self.conv = torch.nn.Conv3d(5, 1, 3, padding=1, bias=False)

    def forward(self, x):
        """Convolve ``x`` and upsample the result by 2 in each spatial dimension.

        ``x`` must be 5-D with 5 channels, i.e. (N, 5, D, H, W): Conv3d(5, ...)
        fixes the channel count and "trilinear" mode requires 5-D input.
        The result has shape (N, 1, 2*D, 2*H, 2*W).
        """
        new_x = self.conv(x)
        up_x = torch.nn.functional.interpolate(
            new_x, scale_factor=2, mode="trilinear", align_corners=True)
        return up_x
# Three spatial sizes: 10 is the size used for tracing; 5 and 15 are not
# seen during tracing.
inp_5 = torch.randn(1, 5, 5, 5, 5)
inp_10 = torch.randn(1, 5, 10, 10, 10)
inp_15 = torch.randn(1, 5, 15, 15, 15)

model = Model()
model.eval()

# torch.jit.trace records the operations executed for the example input only.
# Shape arithmetic done in Python (e.g. interpolate turning scale_factor into
# an output size) can be recorded as constants. The traced graph may then be
# specialised to the example input's spatial size. Whether that happens
# depends on the PyTorch version.
trace = torch.jit.trace(model, inp_10)
trace.save("trace.pth")
t_model = torch.jit.load("trace.pth")

# torch.jit.script compiles forward() itself, so the output size is derived
# from the actual input on every call rather than from the tracing example.
s_model = torch.jit.script(model)

for size, inp in ((5, inp_5), (10, inp_10), (15, inp_15)):
    result_model = model(inp)
    result_t_model = t_model(inp)
    result_s_model = s_model(inp)
    print("Shape {}, {} ||| {}".format(size, result_model.shape, result_t_model.shape))
    # allclose raises when shapes cannot be broadcast together, so compare
    # shapes first, and print the outcome instead of discarding it.
    for label, other in (("traced", result_t_model), ("scripted", result_s_model)):
        same = (result_model.shape == other.shape
                and torch.allclose(result_model, other))
        print("  {} matches eager: {}".format(label, same))
Outputs:
Shape 5, torch.Size([1, 1, 10, 10, 10]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 10, torch.Size([1, 1, 20, 20, 20]) ||| torch.Size([1, 1, 20, 20, 20])
Shape 15, torch.Size([1, 1, 30, 30, 30]) ||| torch.Size([1, 1, 20, 20, 20])
Traceback (most recent call last):
File "main.py", line 37, in <module>
torch.allclose(result_model_5, result_t_model_5)
RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 4
Is this expected behavior?