I am having a pytorch model with the following network structure.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class lenet_mnist(nn.Module):
def __init__(self):
super(lenet_mnist, self).__init__()
self.cuda()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3,padding=1)
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
self.relu3 = nn.ReLU(inplace=True)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5,padding=2)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self,x):
#x = x.to(torch.device("cuda:0"))
out = self.conv1(x)
out = self.maxpool1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.conv2_drop(out)
out = self.maxpool2(out)
out = self.relu2(out)
out = out.view(-1,1568)
out = self.fc1(out)
out = self.fc2(out)
return out
I train this model and I remove few channels in the model, for example I made the self.conv2 as (in=16,out=24). I reloaded the state_dict model and modified the model for the new structure.
Now , I am using torch.onnx function to convert this pytorch model to onnx. Internally, this does calls torch.jit.get_trace and this in turn opens the model.py file for getting the model information. Because of thinning the model’s shape had changed and it does not match the original defined in model.py. How can I avoid this?
edit (more information): For thinning the model, I just get the state_dict of the model, edit the tensors by deleting few channels in it and then use load_state_dict on the model to load the model again. I also change the model._parameters with the updated weights and bias. I use this model to convert to onnx model. While converting, I get an error saying mismatch in dimensions at out.view layer. And, the backtrace shows that the pytorch is reading the model from model.py file where the above code snippet is present. I have pasted the backtrace for the error below.
File "/folder_1/thin_model/dev/thin_model", line 543, in convert_pytorch
torch.onnx.export(model, dummy_input, output_dir + "converted_model.onnx",input_names = input_names)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/onnx/__init__.py", line 25, in export
return utils.export(*args, **kwargs)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/onnx/utils.py", line 84, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/onnx/utils.py", line 134, in _export
trace, torch_out = torch.jit.get_trace_graph(model, args)
File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
self.gen.throw(type, value, traceback)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/onnx/utils.py", line 38, in set_training
yield
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/onnx/utils.py", line 134, in _export
trace, torch_out = torch.jit.get_trace_graph(model, args)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/jit/__init__.py", line 255, in get_trace_graph
return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/jit/__init__.py", line 288, in forward
out = self.inner(*trace_inputs)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self._slow_forward(*input, **kwargs)
File "/folder_1/python_dir/lib/python3.5/site-packages/torch/nn/modules/module.py", line 479, in _slow_forward
result = self.forward(*input, **kwargs)
File "/folder_1/trained_model/pytorch/simple_network/model.py", line 44, in forward
out = out.view(-1,1568)
RuntimeError: invalid argument 2: size '[-1 x 1568]' is invalid for input with 1176 elements at /pytorch/aten/src/TH/THStorage.c:37