How to load the pretrained ResNet models (or any other pretrained models)

Hi All,

Can you please let me know how to load pretrained models using the C++ frontend? I am trying to load the ResNet model with the C++ frontend as follows:

#include <torch/torch.h>
#include <iostream>
#include <memory>

int main()
{
  // weights downloaded from https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth
  torch::nn::Sequential sequential;

  // this call throws the c10::Error shown below
  torch::load(sequential, "resnet18-5c106cde.pth");

  std::cout << c10::str(sequential) << "\n";
}

But the program aborts with SIGABRT:

[ Variable[CPUFloatType]{2,3} ]
terminate called after throwing an instance of 'c10::Error'
  what():  [enforce fail at inline_container.cc:137] . PytorchStreamReader failed reading zip archive: failed finding central directory
frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7f0189f2ef81 in /opt/libtorch/lib/libc10.so)
frame #1: c10::ThrowEnforceNotMet(char const*, int, char const*, std::string const&, void const*) + 0x49 (0x7f0189f2ed99 in /opt/libtorch/lib/libc10.so)
frame #2: caffe2::serialize::PyTorchStreamReader::valid(char const*) + 0x6b (0x7f018bb3894b in /opt/libtorch/lib/libcaffe2.so)
frame #3: caffe2::serialize::PyTorchStreamReader::init() + 0x9d (0x7f018bb3a71d in /opt/libtorch/lib/libcaffe2.so)
frame #4: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::unique_ptr<caffe2::serialize::ReadAdapterInterface, std::default_delete<caffe2::serialize::ReadAdapterInterface> >) + 0x3b (0x7f018bb3c17b in /opt/libtorch/lib/libcaffe2.so)
frame #5: <unknown function> + 0x9087ee (0x7f0194c137ee in /opt/libtorch/lib/libtorch.so.1)
frame #6: torch::jit::load(std::unique_ptr<caffe2::serialize::ReadAdapterInterface, std::default_delete<caffe2::serialize::ReadAdapterInterface> >, c10::optional<c10::Device>, std::unordered_map<std::string, std::string, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::string> > >&) + 0x9a (0x7f0194c16d2a in /opt/libtorch/lib/libtorch.so.1)
frame #7: torch::jit::load(std::string const&, c10::optional<c10::Device>, std::unordered_map<std::string, std::string, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::string> > >&) + 0x68 (0x7f0194c16ec8 in /opt/libtorch/lib/libtorch.so.1)
frame #8: torch::serialize::InputArchive::load_from(std::string const&, c10::optional<c10::Device>) + 0x38 (0x7f0194e09798 in /opt/libtorch/lib/libtorch.so.1)
frame #9: void torch::load<torch::nn::Sequential, char const (&) [22]>(torch::nn::Sequential&, char const (&) [22]) + 0x57 (0x409cd7 in ./CatsAndDogsCnn)
frame #10: main + 0x9e (0x4074ae in ./CatsAndDogsCnn)
frame #11: __libc_start_main + 0xf0 (0x7f01895c9830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #12: _start + 0x29 (0x407599 in ./CatsAndDogsCnn)
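The key line is "PytorchStreamReader failed reading zip archive": frames #6 to #8 show that torch::load handed the file to a reader that expects a ZIP-based archive, such as the one ScriptModule.save() writes. The downloaded resnet18-5c106cde.pth is, as far as I can tell, a checkpoint written by Python's torch.save in its older pickle-based format, so that reader cannot parse it. A minimal, dependency-free sketch of how to check this yourself is to look at the file's first four bytes; the helper names below are hypothetical and not part of libtorch.

```cpp
#include <cstddef>
#include <fstream>
#include <string>

// The 4-byte signature that starts a ZIP local file header: "PK\x03\x04".
// (An empty ZIP archive starts with "PK\x05\x06" instead; a saved model
// never produces an empty archive, so that case is not checked here.)
bool has_zip_signature(const std::string& first_bytes) {
    static const std::string kZipMagic("PK\x03\x04", 4);
    // compare() on a shorter string simply reports a mismatch
    return first_bytes.compare(0, kZipMagic.size(), kZipMagic) == 0;
}

// Reads the first 4 bytes of the file at `path` and checks them.
// Returns false if the file cannot be opened or is too short.
bool file_looks_like_zip(const std::string& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file) {
        return false;
    }
    std::string header(4, '\0');
    file.read(&header[0], static_cast<std::streamsize>(header.size()));
    header.resize(static_cast<std::size_t>(file.gcount()));
    return has_zip_signature(header);
}
```

Passing this check does not guarantee that a file can be loaded into a particular C++ module; failing it means the ZIP-based reader will reject the file, as in the trace above.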

Hey @SantMan,

this should get you going again. First, use Python to convert the network into a ScriptModule, for example with a script called convert.py:

import torch
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)
resnet18.eval()  # trace in inference mode, so BatchNorm uses its running statistics

# use an example input to trace the operations of the model
example_input = torch.rand(1, 3, 224, 224)  # 224x224 is the standard ImageNet input size; use the size of your data

script_module = torch.jit.trace(resnet18, example_input)
script_module.save('script_module.pt')

Then run it. You can now load the ScriptModule from C++ like so:

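#include <torch/script.h>

#include <iostream>
#include <memory>
#include <vector>

int main()
{
  // load the traced model (the path is relative to the build folder)
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../script_module.pt");

  // example forward pass
  std::vector<torch::jit::IValue> input;
  input.push_back(torch::zeros({1, 3, 224, 224}));

  torch::Tensor output = module->forward(input).toTensor();
  std::cout << output.sizes() << "\n";  // 1 x 1000 class scores for ResNet-18

  return 0;
}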

My folder structure is as follows:

convert.py
CMakeLists.txt
main.cpp
build/

and CMakeLists.txt looks as follows:

cmake_minimum_required(VERSION 3.11 FATAL_ERROR)

project(load_resnet)

find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})

add_executable(main main.cpp)
target_link_libraries(main ${TORCH_LIBRARIES})
set_property(TARGET main PROPERTY CXX_STANDARD 11)

Hi @mhubii,

Thanks for the input. I gave the suggested method a quick try and it works. Next I want to try it on a cats-and-dogs classifier; I will let you know how it goes. Meanwhile, is there an alternative way to serialize the model without using Python? :slight_smile: Even if it is a work in progress, I would love to know about it.

Cheers,
SantMan

Hi @mhubii,

I have a couple of questions about loading script_module.pt.

Here is what I want to do in C++ (shown as the equivalent PyTorch code):

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

model = models.densenet121(pretrained=True)
print(model)

The model has two parts: “features” (a Sequential) and “classifier”.

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )

........
........
........
(classifier): Linear(in_features=1024, out_features=1000, bias=True)

I want to freeze the “features” Sequential and replace the “classifier” with my own, as follows:

# Freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False

from collections import OrderedDict
classifier = nn.Sequential(OrderedDict([
                          ('fc1', nn.Linear(1024, 500)),
                          ('relu', nn.ReLU()),
                          ('fc2', nn.Linear(500, 2)),
                          ('output', nn.LogSoftmax(dim=1))
                          ]))
    
model.classifier = classifier
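The replacement classifier ends in nn.LogSoftmax(dim=1), so for each image it outputs log-probabilities of the two classes rather than raw scores. A minimal C++ sketch of that last step for a single sample (one row of the [batch, 2] output), written without libtorch so it runs on its own; the function name is illustrative:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Log-softmax over one sample's logits, i.e. what nn.LogSoftmax(dim=1) computes
// for each row of a [batch, classes] tensor. Subtracting the max logit first keeps
// std::exp from overflowing; the result is mathematically unchanged.
std::vector<double> log_softmax(const std::vector<double>& logits) {
    assert(!logits.empty());
    const double max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (double z : logits) {
        sum_exp += std::exp(z - max_logit);
    }
    const double log_sum_exp = max_logit + std::log(sum_exp);
    std::vector<double> out;
    out.reserve(logits.size());
    for (double z : logits) {
        out.push_back(z - log_sum_exp);
    }
    return out;
}
```

For example, log_softmax({2.0, 0.5}) returns two negative values whose exponentials sum to 1; the larger one marks the predicted class.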

I am not able to do the same in C++, mainly because I cannot load the script module

module = torch::jit::load("../script_module.pt"); 

as a torch::nn::Module or torch::nn::Sequential. Any ideas on how to achieve this?

Thanks and Regards,
SantMan

You should be able to load it as a ScriptModule if you trace the DenseNet model the same way as the ResNet above. Are you trying to do the fine-tuning in C++?

Hi @mhubii,

No, I’m not trying to fine-tune the pretrained models. I’m trying to reuse them for my own task.

If you look at the network, its classifier has 1024 input features and 1000 output features:

(classifier): Linear(in_features=1024, out_features=1000, bias=True)

I, however, only have 2 outputs (cats and dogs), so I need to change the classifier to 2 output features. At the same time, I don’t want to retrain the complete network, so I want to freeze it and change only the classifier. I know how to do all these steps in PyTorch, as posted in my previous comment.
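To put numbers on "change only the classifier": with the features frozen, only the new head's weights and biases are trained. A linear layer with bias has in_features × out_features + out_features parameters, so the arithmetic for the layers in this thread can be sketched as below (names are illustrative; ReLU and LogSoftmax add no parameters):

```cpp
#include <cstdint>

// Number of trainable parameters in a fully connected layer with bias,
// as in nn.Linear(in_features, out_features): an out_features x in_features
// weight matrix plus one bias per output.
constexpr std::int64_t linear_params(std::int64_t in_features, std::int64_t out_features) {
    return in_features * out_features + out_features;
}

// DenseNet-121's original head: Linear(1024, 1000).
constexpr std::int64_t kOriginalHead = linear_params(1024, 1000);

// The replacement head from the earlier post:
// Linear(1024, 500) -> ReLU -> Linear(500, 2) -> LogSoftmax.
constexpr std::int64_t kNewHead = linear_params(1024, 500) + linear_params(500, 2);
```

So the original Linear(1024, 1000) head holds 1,025,000 parameters and the two-layer replacement holds 513,502; a single Linear(1024, 2) would need only 2,050.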

Now, when I try to do the same with the C++ frontend, I can’t. The network I load is a torch::jit::script::Module, which only lets me load the stored model; I can’t change its output features from 1000 to 2. I think this is because I can’t convert the torch::jit::script::Module to either a torch::nn::Module or a torch::nn::Sequential.

I want to know if there is any way in which I can convert the torch::jit::script::Module to either torch::nn::Module or torch::nn::Sequential.

I know that I could modify the model in convert.py before tracing it with torch.jit.trace(), save it, and load the complete model in C++. That may work (I have yet to try it), but I am more interested in a solution where I can convert a torch::jit::script::Module to either a torch::nn::Module or a torch::nn::Sequential.

Please let me know whether I have stated my problem clearly; I can explain more if required.

Thanks and Regards,
SantMan

Okay, I see. As far as I understand, the members of torch::jit::script::Module are private, so you will only be able to run inference on the loaded module. You could try removing the last layer in Python with

import torch.nn as nn

# remove the last layer (the fully connected classifier)
resnet18 = nn.Sequential(*list(resnet18.children())[:-1])
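With the last child (the fc layer) removed, the truncated network ends in ResNet's average-pooling layer, so the traced module should return a tensor of shape [N, 512, 1, 1] (assuming torchvision's ResNet-18, which pools to 1 × 1; 512 is the input size of its final fc layer). That is why the C++ code calls squeeze() before the new linear layer. A dependency-free sketch of squeeze()'s shape rule, where every size-1 dimension is dropped; the function name is hypothetical:

```cpp
#include <cstdint>
#include <vector>

// Shape rule of Tensor::squeeze() with no arguments: every dimension of size 1
// is removed. Shapes are represented as plain vectors of dimension sizes.
std::vector<std::int64_t> squeeze_shape(const std::vector<std::int64_t>& shape) {
    std::vector<std::int64_t> out;
    for (std::int64_t dim : shape) {
        if (dim != 1) {
            out.push_back(dim);
        }
    }
    return out;
}
```

Note the batch-size-1 case: squeeze turns [1, 512, 1, 1] into [512], dropping the batch dimension as well, whereas a batch of 8 gives [8, 512]. If you later train on batches, flattening from dimension 1 (e.g. flatten(1)) keeps the batch dimension in both cases.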

Then trace resnet18, load the truncated module in C++, add a linear layer on top of it, and pass only the parameters of that linear layer to the optimizer. Something like:

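#include <torch/script.h>
#include <torch/torch.h>

#include <memory>
#include <vector>

int main()
{
  // ResNet-18 with the last layer removed in Python before tracing
  std::shared_ptr<torch::jit::script::Module> module;
  module = torch::jit::load("../script_module.pt");

  // replacement for the last layer
  torch::nn::Linear lin(512, 2); // the last layer of resnet, which you want to replace, has dimensions 512x1000
  // only the new layer's parameters are handed to the optimizer
  torch::optim::Adam opt(lin->parameters(), torch::optim::AdamOptions(1e-3 /*learning rate*/));

  // example input
  std::vector<torch::jit::IValue> input;
  input.push_back(torch::zeros({1, 3, 224, 224}));

  // your cat/dog label, one-hot encoded
  torch::Tensor label = torch::zeros({2});
  label[0] = 1.;

  // one step of some training loop
  torch::Tensor features;
  {
    // the pretrained part is frozen, so don't build its autograd graph
    torch::NoGradGuard no_grad;
    features = module->forward(input).toTensor().squeeze(); // [1, 512, 1, 1] -> [512]
  }
  torch::Tensor out = lin(features);
  torch::Tensor loss = torch::mse_loss(out, label);
  opt.zero_grad();
  loss.backward();
  opt.step();

  return 0;
}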
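The snippet above trains the new layer against a one-hot label with torch::mse_loss. Below is a dependency-free sketch of what that loss computes with its default mean reduction, next to the negative log-likelihood that pairs with a LogSoftmax output like the Python classifier earlier in the thread; for classification the latter is the more common choice. The function names are illustrative. In libtorch I would expect the corresponding calls to be torch::log_softmax followed by torch::nll_loss with a class-index target, but check the API of your version.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One-hot encoding as in the snippet above: zeros with a 1 at the class index.
std::vector<double> one_hot(std::size_t class_index, std::size_t num_classes) {
    assert(class_index < num_classes);
    std::vector<double> v(num_classes, 0.0);
    v[class_index] = 1.0;
    return v;
}

// Mean squared error with mean reduction (the default of torch::mse_loss):
// the average of the squared element-wise differences.
double mse(const std::vector<double>& prediction, const std::vector<double>& target) {
    assert(prediction.size() == target.size() && !prediction.empty());
    double sum = 0.0;
    for (std::size_t i = 0; i < prediction.size(); ++i) {
        const double diff = prediction[i] - target[i];
        sum += diff * diff;
    }
    return sum / static_cast<double>(prediction.size());
}

// Negative log-likelihood for one sample: given log-probabilities (e.g. the
// output of a LogSoftmax layer) and the true class index, the loss is
// -log_probs[target]. No one-hot vector is needed.
double nll(const std::vector<double>& log_probs, std::size_t target) {
    assert(target < log_probs.size());
    return -log_probs[target];
}
```

For an output of {0.5, 0.5} and the label one_hot(0, 2), mse returns 0.25.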

Hope it helps. It may feel a little hacky at the moment, but I am sure that future releases will provide C++ implementations of popular architectures. Please report back whether it works for you :slightly_smiling_face: :+1:


Hi @mhubii,

Thanks for the response. I will try that out and get back with my results or questions :slight_smile:

Best Regards,
SantMan
