Fully convolutional JIT module with variable input size in LibTorch

Hi all,

I have an issue with a TorchScript module traced from torchvision.models.segmentation.deeplabv3_resnet50. Since it only contains operations that do not fix the input size, I was hoping to load it into C++ and use it for inference on differently sized 2D inputs.

Basically, in Python I create the model (training it is not relevant to the error), trace it, and save it. Then in C++ I load it and try to run inference. At runtime the input size can vary, and I thought this would not be a problem, but for any input that does not have the same shape as the example input provided to torch.jit.trace I always get: Microsoft C++ exception: cudnn_frontend::cudnnException at memory location 0x00000052344FBEA0.

Am I
a) wrong in my understanding that the model should be able to accept variable-shaped input, or
b) doing something wrong?

Thank you for your help. Example code and library versions for reproduction are below; let me know if anything else is needed.

In Python (torch==2.0.0+cu118, torchvision==0.15.1+cu118 on win64), I do:

import torch
import torchvision

class wrapper(torch.nn.Module):
    def __init__(self, model):
        super(wrapper, self).__init__()
        self.model = model

    def forward(self, input):
        # the segmentation model returns a dict of tensors; keep only "out"
        output = self.model(input)
        return output["out"]

numClasses = 29

net = torchvision.models.segmentation.deeplabv3_resnet50()
net.classifier[4] = torch.nn.Conv2d(256, numClasses, kernel_size=(1, 1), stride=(1, 1))
net.eval()
model = wrapper(net)

device = torch.device("cuda")
model = model.to(device)
with torch.no_grad():
    raw_t = torch.rand((1, 3, 2048, 2448), device=device)

    traced_script_module = torch.jit.trace(model, raw_t)
    traced_script_module.save("D:\\Models\\model_2048_2448_cuda.pt")

and then in C++ (LibTorch 2.0.0+cu118, Visual Studio 2019, v142 toolset):

#include <torch/script.h>
#include <iostream>

void testjit() {
	auto model = torch::jit::load("D:/Models/model_2048_2448_cuda.pt");
	model.to(torch::kCUDA);
	model.eval();
	if (true)
	{
		torch::NoGradGuard guard;
		torch::Tensor inp = torch::rand({ 1, 3, 2048, 2448 }, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
		std::vector<torch::jit::IValue> inputs;
		inputs.push_back(inp);

		model.forward(inputs); // No problem, same shape as used for tracing in Python
		std::cout << "2048, 2448 ran successfully" << std::endl;
	}
	if (true)
	{
		torch::Tensor inp = torch::rand({ 1,3,1444,1444 }, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
		std::vector<torch::jit::IValue> inputs;
		inputs.push_back(inp);
		model.forward(inputs);
		std::cout << "1444,1444 ran succesfully" << std::endl;
	}
	if(true)
	{
		torch::Tensor inp = torch::rand({ 1,3,1444,1443 }, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
		std::vector<torch::jit::IValue> inputs;
		inputs.push_back(inp);
		model.forward(inputs); // Throws: Microsoft C++ exception: cudnn_frontend::cudnnException at memory location 0x00000052344FBEA0.
		std::cout << "1444,1443 ran successfully" << std::endl;
	}
}

Probably yes, since you are tracing the model, which only records the executed code path without capturing conditions etc.
You could check how variable input shapes are handled internally in the model and whether that logic needs conditions etc. If so, you could then try to torch.jit.script the model instead.

However, also note that TorchScript is in “maintenance” mode and could thus easily break.
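
That suggestion as a minimal sketch: this assumes the wrapper module from the snippet above is scriptable (torchvision's segmentation models generally are) and uses a hypothetical save path:

# scripting compiles the forward's control flow instead of recording a single
# execution with a fixed example input, so shape-dependent logic is preserved
scripted = torch.jit.script(model)
scripted.save("D:\\Models\\model_scripted_cuda.pt")  # hypothetical path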

I have seen that the model uses AdaptiveAvgPool2d at some point. That probably doesn’t play nicely with tracing the model for a fixed input, as it needs internal conditions to handle variable input sizes?

I don’t believe adaptive pooling layers are causing the issue, as seen in this small example:

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(2).cuda()
x = torch.randn(1, 3, 24, 24, device="cuda")

out = pool(x)
print(out.shape)
# torch.Size([1, 3, 2, 2])

traced = torch.jit.trace(pool, x)

out_traced = traced(x)
print(out_traced.shape)
# torch.Size([1, 3, 2, 2])

x = torch.randn(1, 3, 3, 3, device="cuda")
out_traced = traced(x)
print(out_traced.shape)
# torch.Size([1, 3, 2, 2])

so maybe other ops are failing in your model.
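
One way to narrow this down: torch.jit.trace accepts a check_inputs argument that re-runs the trace against additional inputs and reports where it diverges from eager execution. A hedged sketch, reusing model, raw_t, and device from the first snippet:

with torch.no_grad():
    # a differently shaped probe; the tracer compares traced vs. eager outputs
    # for each tuple of check inputs and warns/errors on divergence
    check = [(torch.rand(1, 3, 1444, 1443, device=device),)]
    traced = torch.jit.trace(model, raw_t, check_inputs=check)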

Thank you for the insight, you are indeed right. I have gone through the torchvision source, and it looks like the issue stems from the _SimpleSegmentationModel base class in models/segmentation/_utils. Its forward function is defined as follows:

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result

where the interpolate target size depends on the input shape, which tracing fixes to the example input's size. I tested this by leaving out the interpolation, and indeed running differently shaped inputs in C++ now works (with an unexpected twist, see below). What I did in Python is:

import torch
import torchvision

class wrappWithoutInterpolate(torch.nn.Module):
    def __init__(self, model):
        super(wrappWithoutInterpolate, self).__init__()
        self.model = model

    def forward(self, x):
        # leave out the final interpolation and only keep the output of the
        # classifier, not the auxiliary classifier
        features = self.model.backbone(x)
        x = features["out"]
        x = self.model.classifier(x)
        return x

numClasses = 29

net = torchvision.models.segmentation.deeplabv3_resnet50()
net.classifier[4] = torch.nn.Conv2d(256, numClasses, kernel_size=(1, 1), stride=(1, 1))
net.eval()
net = net.to(torch.device('cuda'))
wrapped = wrappWithoutInterpolate(net)

with torch.no_grad():
    raw_t = torch.rand((1, 3, 2048, 2448), device=torch.device('cuda'))
    traced_script_module = torch.jit.trace(wrapped, raw_t)
    traced_script_module.save("D:\\Models\\test_wrapped_resnet.pt")

In C++ I can now do the following, but I need to clone the model for some reason (any idea why, @ptrblck?):

#include <torch/script.h>
#include <torch/torch.h>

int main() {
	auto model = torch::jit::load("D:\\Models\\test_wrapped_resnet.pt");
	model.eval();
	model.to(torch::kCUDA);
	{
		torch::NoGradGuard _guard;
		// If I don't clone here, the second run with a different input simply runs forever
		auto tmpModel = model.clone();
		torch::Tensor wrong = torch::randn({ 1, 3, 1024, 1024 }, torch::TensorOptions().device(torch::kCUDA));
		std::vector<torch::jit::IValue> inputs_wrong;
		inputs_wrong.push_back(wrong);
		auto tmp_wrong = tmpModel.forward(inputs_wrong);
		auto ret_wrong = tmp_wrong.toTensor();
		auto final_ret = torch::nn::functional::interpolate(ret_wrong,
			torch::nn::functional::InterpolateFuncOptions().mode(torch::kBilinear).size(std::vector<int64_t>{ wrong.size(2), wrong.size(3) }).align_corners(false));

	}
	{
		auto tmpModel = model.clone();
		torch::NoGradGuard _guard;
		torch::Tensor correct = torch::randn({ 1, 3, 2048, 2448 }, torch::TensorOptions().device(torch::kCUDA));
		std::vector<torch::jit::IValue> inputs_correct;
		inputs_correct.push_back(correct);
		auto tmp_correct = tmpModel.forward(inputs_correct);
		auto ret_correct = tmp_correct.toTensor();
		auto final_ret = torch::nn::functional::interpolate(ret_correct,
			torch::nn::functional::InterpolateFuncOptions().mode(torch::kBilinear).size(std::vector<int64_t>{ correct.size(2), correct.size(3) }).align_corners(false));

	}
}
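
For reference, the same post-processing can be cross-checked on the Python side. A minimal sketch mirroring the C++ interpolate calls above, reusing traced_script_module and probe from the sanity check earlier:

import torch.nn.functional as F

with torch.no_grad():
    logits = traced_script_module(probe)
    # interpolate the logits back to the input resolution, as the removed
    # _SimpleSegmentationModel.forward did
    full = F.interpolate(logits, size=probe.shape[-2:], mode="bilinear", align_corners=False)
    print(full.shape)  # torch.Size([1, 29, 1024, 1024])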