when I try to import the following torch model to pytorch,
require 'torch’
require 'nn’
require 'nnx’
require 'optim’
require ‘rnn’
function buildModel_MeanPool_RNN(nFltrs1,nFltrs2,nFltrs3,nPersonsTrain)
local nFilters = {nFltrs1,nFltrs2,nFltrs3}
local filtsize = {5,5,5}
local poolsize = {2,2,2}
local stepSize = {2,2,2}
-- remember this adds padding to ALL SIDES of the image
local padDim = 4
local cnn = nn.Sequential()
local ninputChannels = 5
cnn:add(nn.SpatialZeroPadding(padDim, padDim, padDim, padDim))
cnn:add(nn.SpatialConvolutionMM(ninputChannels, nFilters[1], filtsize[1], filtsize[1], 1, 1))
cnn:add(nn.Tanh())
cnn:add(nn.SpatialMaxPooling(poolsize[1],poolsize[1],stepSize[1],stepSize[1]))
ninputChannels = nFilters[1]
cnn:add(nn.SpatialZeroPadding(padDim, padDim, padDim, padDim))
cnn:add(nn.SpatialConvolutionMM(ninputChannels, nFilters[2], filtsize[2], filtsize[2], 1, 1))
cnn:add(nn.Tanh())
cnn:add(nn.SpatialMaxPooling(poolsize[2],poolsize[2],stepSize[2],stepSize[2]))
ninputChannels = nFilters[2]
cnn:add(nn.SpatialZeroPadding(padDim, padDim, padDim, padDim))
cnn:add(nn.SpatialConvolutionMM(ninputChannels, nFilters[3], filtsize[3], filtsize[3], 1, 1))
cnn:add(nn.Tanh())
cnn:add(nn.SpatialMaxPooling(poolsize[3],poolsize[3],stepSize[3],stepSize[3]))
local nFullyConnected = nFilters[3]*10*8
cnn:add(nn.Reshape(1,nFullyConnected))
cnn:add(nn.Dropout(0.6))
cnn:add(nn.Linear(nFullyConnected,128))
-- cnn:cuda()
local h2h = nn.Sequential()
h2h:add(nn.Tanh())
h2h:add(nn.Dropout(0.6))
h2h:add(nn.Linear(128,128))
-- h2h:cuda()
local r1 = nn.Recurrent(
128,
cnn,
h2h,
nn.Identity(),
16)
local rnn1 = nn.Sequencer(
nn.Sequential()
:add(r1)
)
Combined_CNN_RNN_1 = nn.Sequential()
Combined_CNN_RNN_1:add(rnn1)
Combined_CNN_RNN_1:add(nn.JoinTable(1))
Combined_CNN_RNN_1:add(nn.Mean(1))
local r2 = nn.Recurrent(
128,
cnn:clone('weight','bias','gradWeight','gradBias'),
h2h:clone('weight','bias','gradWeight','gradBias'),
nn.Identity(),
16)
local rnn2 = nn.Sequencer(
nn.Sequential()
:add(r2)
)
Combined_CNN_RNN_2 = nn.Sequential()
Combined_CNN_RNN_2:add(rnn2)
Combined_CNN_RNN_2:add(nn.JoinTable(1))
Combined_CNN_RNN_2:add(nn.Mean(1))
-- Combined_CNN_RNN_2 = Combined_CNN_RNN_1:clone('weight','bias','gradWeight','gradBias')
local mlp2 = nn.ParallelTable()
mlp2:add(Combined_CNN_RNN_1)
mlp2:add(Combined_CNN_RNN_2)
-- mlp2:cuda()
local mlp3 = nn.ConcatTable()
mlp3:add(nn.Identity())
mlp3:add(nn.Identity())
mlp3:add(nn.Identity())
-- mlp3:cuda()
local mlp4 = nn.ParallelTable()
mlp4:add(nn.Identity())
mlp4:add(nn.SelectTable(1))
mlp4:add(nn.SelectTable(2))
-- mlp4:cuda()
-- used to predict the identity of each person
local classifierLayer = nn.Linear(128,nPersonsTrain)
-- identification
local mlp6 = nn.Sequential()
mlp6:add(classifierLayer)
mlp6:add(nn.LogSoftMax())
-- mlp6:cuda()
local mlp7 = nn.Sequential()
mlp7:add(classifierLayer:clone('weight','bias','gradWeight','gradBias'))
mlp7:add(nn.LogSoftMax())
-- mlp7:cuda()
local mlp5 = nn.ParallelTable()
mlp5:add(nn.PairwiseDistance(2))
mlp5:add(mlp6)
mlp5:add(mlp7)
-- mlp5:cuda()
local fullModel = nn.Sequential()
fullModel:add(mlp2)
fullModel:add(mlp3)
fullModel:add(mlp4)
fullModel:add(mlp5)
-- fullModel:cuda()
local crit = nn.SuperCriterion()
crit:add(nn.HingeEmbeddingCriterion(2),1)
crit:add(nn.ClassNLLCriterion(),1)
crit:add(nn.ClassNLLCriterion(),1)
return fullModel, crit, Combined_CNN_RNN_1, cnn
end
fullModel, crit, Combined_CNN_RNN_1, cnn = buildModel_MeanPool_RNN(16,32,32,150)
torch.save(‘Combined_CNN_RNN.t7’,Combined_CNN_RNN_1)
then in pytorch,
import torch
from torch.utils.serialization import load_lua
net = load_lua(‘Combined_CNN_RNN.t7’)
But it is not working:
Traceback (most recent call last):
File “/mnt/68FC8564543F417E/Pytorch/convert_torch_to_pytorch-master/load_torch_net.py”, line 4, in
net = load_lua(‘Combined_CNN_RNN.t7’)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 599, in load_lua
return reader.read()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 584, in read
return self.read_object()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 514, in wrapper
result = fn(self, *args, **kwargs)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 537, in read_object
return reader_registry[cls_name](self, version)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 242, in read_nn_class
attributes = reader.read()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 586, in read
return self.read_table()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 514, in wrapper
result = fn(self, *args, **kwargs)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 563, in read_table
v = self.read()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 586, in read
return self.read_table()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 514, in wrapper
result = fn(self, *args, **kwargs)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 563, in read_table
v = self.read()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 584, in read
return self.read_object()
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 514, in wrapper
result = fn(self, *args, **kwargs)
File “/home/chengli/torch-env/local/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.py”, line 543, in read_object
"constructor").format(cls_name))
torch.utils.serialization.read_lua_file.T7ReaderException: don’t know how to deserialize Lua class nn.Sequencer. If you want to ignore this error and load this object as a dict, specify unknown_classes=True in reader’s constructor
what should I do to make it work?