How to load Caffe models in PyTorch

Hi everyone, with Torch we had loadcaffe, which was extremely useful. Is there any way to achieve the same functionality in PyTorch? Simply put, how can we load Caffe models in PyTorch?

Thanks a lot!


One way would be to use loadcaffe to load the weights in Lua, save them, and use torch.utils.serialization.load_lua to load them in PyTorch. Unfortunately, we don’t have a ready-made solution for that yet.


Thanks, Adam.
I hope there is some plan for a future solution?

Not anytime soon I’m afraid. There are a lot of high priority tasks.

But it should be quite simple for someone to add that. You just need to read protobufs and translate graphs into modules.
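To give a feel for what "translate graphs into modules" involves, here is a minimal stdlib-only sketch. It assumes the prototxt has already been parsed into plain dicts (a real converter would read it with protobuf and the compiled caffe.proto), handles only a linear chain of Convolution, ReLU, MAX Pooling, InnerProduct and Dropout layers, and emits torch.nn constructor strings instead of building modules. The helper name `caffe_layers_to_torch` and the dict layout are hypothetical.

```python
import math
from typing import Dict, List


def caffe_layers_to_torch(layers: List[Dict], in_channels: int, in_size: int) -> List[str]:
    """Map a linear chain of already-parsed Caffe layers to torch.nn constructor strings.

    Hypothetical helper: each layer is a plain dict such as
    {"type": "Convolution", "num_output": 64, "kernel_size": 3, "pad": 1, "stride": 1}.
    Channels and the (square) spatial size are tracked so that the first
    InnerProduct layer knows its in_features.
    """
    channels, size = in_channels, in_size
    features = None  # feature count once the activations have been flattened
    modules = []
    for layer in layers:
        kind = layer["type"]
        if kind == "Convolution":
            # Caffe defaults: pad 0, stride 1
            k, p, s = layer["kernel_size"], layer.get("pad", 0), layer.get("stride", 1)
            modules.append(
                f"nn.Conv2d({channels}, {layer['num_output']}, kernel_size={k}, stride={s}, padding={p})"
            )
            channels = layer["num_output"]
            size = (size + 2 * p - k) // s + 1  # Caffe and PyTorch both round down here
        elif kind == "Pooling":
            if layer.get("pool", "MAX") != "MAX":
                raise NotImplementedError("only MAX pooling is handled in this sketch")
            k, s = layer["kernel_size"], layer.get("stride", 1)
            # Caffe rounds pooled sizes up, so ceil_mode=True keeps the shapes identical
            # (pooling padding is ignored in this sketch).
            modules.append(f"nn.MaxPool2d(kernel_size={k}, stride={s}, ceil_mode=True)")
            size = math.ceil((size - k) / s) + 1
        elif kind == "ReLU":
            modules.append("nn.ReLU(inplace=True)")
        elif kind == "Dropout":
            modules.append(f"nn.Dropout(p={layer.get('dropout_ratio', 0.5)})")
        elif kind == "InnerProduct":
            if features is None:  # first fully connected layer: flatten C x H x W
                features = channels * size * size
                modules.append("nn.Flatten()")
            modules.append(f"nn.Linear({features}, {layer['num_output']})")
            features = layer["num_output"]
        else:
            raise NotImplementedError(f"unsupported Caffe layer type: {kind}")
    return modules


# Usage: a tiny hypothetical chain on a 3 x 224 x 224 input
example = [
    {"type": "Convolution", "num_output": 64, "kernel_size": 3, "pad": 1},
    {"type": "ReLU"},
    {"type": "Pooling", "kernel_size": 2, "stride": 2},
    {"type": "InnerProduct", "num_output": 101},
]
for spec in caffe_layers_to_torch(example, in_channels=3, in_size=224):
    print(spec)
```

The ceil_mode detail is the kind of thing a converter has to get right: Caffe rounds pooled sizes up while PyTorch rounds down by default, so odd feature-map sizes would otherwise drift apart. The weights themselves would still have to be copied layer by layer, and nn.Flatten only exists in newer PyTorch releases.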

Got it. Thanks a lot for the suggestions.

@apaszke, do you mean that I can load Caffe models into PyTorch by going through Lua Torch?

Yes, it should be possible, but these will only be torch.legacy.nn models, not torch.nn networks.

Hello, I am not familiar with the Caffe framework. I want to convert this model definition into PyTorch: https://github.com/alexgkendall/caffe-posenet/blob/master/models/bvlc_googlenet/deploy.prototxt
I am using Arch Linux. Would that be possible?

I have finally fixed my problem.
Thanks to everyone; I have integrated the solution into my project.

Good evening,

Following your advice, apaszke, I downloaded loadcaffe and converted the Caffe model and its prototxt file into a .t7 model file.

I am using this to port this model from Caffe to PyTorch. Specifically, I am going with the age estimation variant.

require 'loadcaffe'
require 'torch'

-- Convert the Caffe prototxt + weights into a Torch 'nn' model and save it as .t7
local model = loadcaffe.load('age_train.prototxt', 'dex_chalearn_iccv2015.caffemodel', 'nn')
torch.save('imdb.t7', model)

The resulting model is:

nn.Sequential {
[input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> (25) -> (26) -> (27) -> (28) -> (29) -> (30) -> (31) -> (32) -> (33) -> (34) -> (35) -> (36) -> (37) -> (38) -> (39) -> output]
(1): nn.SpatialConvolution(3 -> 64, 3x3, 1,1, 1,1)
(2): nn.ReLU
(3): nn.SpatialConvolution(64 -> 64, 3x3, 1,1, 1,1)
(4): nn.ReLU
(5): nn.SpatialMaxPooling(2x2, 2,2)
(6): nn.SpatialConvolution(64 -> 128, 3x3, 1,1, 1,1)
(7): nn.ReLU
(8): nn.SpatialConvolution(128 -> 128, 3x3, 1,1, 1,1)
(9): nn.ReLU
(10): nn.SpatialMaxPooling(2x2, 2,2)
(11): nn.SpatialConvolution(128 -> 256, 3x3, 1,1, 1,1)
(12): nn.ReLU
(13): nn.SpatialConvolution(256 -> 256, 3x3, 1,1, 1,1)
(14): nn.ReLU
(15): nn.SpatialConvolution(256 -> 256, 3x3, 1,1, 1,1)
(16): nn.ReLU
(17): nn.SpatialMaxPooling(2x2, 2,2)
(18): nn.SpatialConvolution(256 -> 512, 3x3, 1,1, 1,1)
(19): nn.ReLU
(20): nn.SpatialConvolution(512 -> 512, 3x3, 1,1, 1,1)
(21): nn.ReLU
(22): nn.SpatialConvolution(512 -> 512, 3x3, 1,1, 1,1)
(23): nn.ReLU
(24): nn.SpatialMaxPooling(2x2, 2,2)
(25): nn.SpatialConvolution(512 -> 512, 3x3, 1,1, 1,1)
(26): nn.ReLU
(27): nn.SpatialConvolution(512 -> 512, 3x3, 1,1, 1,1)
(28): nn.ReLU
(29): nn.SpatialConvolution(512 -> 512, 3x3, 1,1, 1,1)
(30): nn.ReLU
(31): nn.SpatialMaxPooling(2x2, 2,2)
(32): nn.View(-1)
(33): nn.Linear(25088 -> 4096)
(34): nn.ReLU
(35): nn.Dropout(0.500000)
(36): nn.Linear(4096 -> 4096)
(37): nn.ReLU
(38): nn.Dropout(0.500000)
(39): nn.Linear(4096 -> 101)
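As a sanity check on this printout, the 25088 inputs of layer (33) follow from the conv/pool stack if the input is 224x224. That input size is an assumption (it is the standard VGG-16 input, which this stack matches, but the thread does not state it). Each 3x3 convolution with stride 1 and padding 1 keeps the spatial size, and each of the five 2x2, stride-2 max-poolings halves it, so 224 / 2^5 = 7 and the flattened vector has 512 * 7 * 7 = 25088 entries. The hypothetical helper below just replays that arithmetic:

```python
def vgg_flat_features(input_size: int, channels: int = 512, num_pools: int = 5) -> int:
    """Return the flattened feature count after a VGG-style conv/pool stack.

    Assumes 3x3 convs with stride 1 and padding 1 (size-preserving) and
    2x2 max-pooling with stride 2 (size-halving), as in the printout above.
    """
    size = input_size
    for _ in range(num_pools):
        size = size // 2  # each SpatialMaxPooling(2x2, 2,2) halves H and W
    return channels * size * size


print(vgg_flat_features(224))  # 25088, matching nn.Linear(25088 -> 4096)
```

When rebuilding this stack with torch.nn, the nn.View(-1) step is usually written as x.view(x.size(0), -1) so that the batch dimension is kept.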

When I call torch.utils.serialization.load_lua('imdb.t7'), it fails and reports that the .t7 file may be corrupted.

I was hoping to get some guidance.

Here is the full traceback:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 608, in load_lua
    return reader.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
    return self.read_object()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
    result = fn(self, *args, **kwargs)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
    return reader_registry[cls_name](self, version)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
    attributes = reader.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
    return self.read_table()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
    result = fn(self, *args, **kwargs)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
    v = self.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
    return self.read_table()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
    result = fn(self, *args, **kwargs)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
    v = self.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
    return self.read_object()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
    result = fn(self, *args, **kwargs)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
    return reader_registry[cls_name](self, version)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
    obj = build_fn(reader, version)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
    obj = build_fn(reader, version)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
    attributes = reader.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
    return self.read_table()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
    result = fn(self, *args, **kwargs)
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 571, in read_table
    k = self.read()
  File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 598, in read
    "corrupted.".format(typeidx))
torch.utils.serialization.read_lua_file.T7ReaderException: unknown type id -1156370551. The file may be corrupted.
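One possible explanation, offered as an assumption rather than something confirmed in this thread: Torch7 writes tensor sizes and strides as C longs, which are 8 bytes on 64-bit Linux and macOS but 4 bytes on Windows. The paths in the traceback show the file being read on Windows, so if it was saved on Linux and read with a 4-byte long, every later field is shifted and a garbage type id like -1156370551 is the kind of symptom you would expect. The stdlib-only sketch below does not call load_lua; it only demonstrates the misalignment on a hypothetical two-field record (an 8-byte long followed by a 4-byte type id). The helpers `make_record` and `read_record` are illustrative names.

```python
import struct

TYPE_TORCH = 4  # Torch7's serialization type id for a Torch object


def make_record(long_value: int, type_id: int) -> bytes:
    """Hypothetical record laid out as a 64-bit Linux writer would emit it:
    an 8-byte little-endian long followed by a 4-byte little-endian int type id."""
    return struct.pack("<q", long_value) + struct.pack("<i", type_id)


def read_record(data: bytes, long_size: int) -> tuple:
    """Read the record back assuming a C long of `long_size` bytes (4 or 8)."""
    long_fmt = {4: "<i", 8: "<q"}[long_size]
    (long_value,) = struct.unpack_from(long_fmt, data, 0)
    # The type id is expected right after the long, so a wrong long size shifts it.
    (type_id,) = struct.unpack_from("<i", data, long_size)
    return long_value, type_id


record = make_record(1 << 40, TYPE_TORCH)
print(read_record(record, long_size=8))  # (1099511627776, 4): correct alignment
print(read_record(record, long_size=4))  # (0, 256): misaligned, bogus type id
print(struct.calcsize("l"))  # native C long size: 4 on Windows, 8 on 64-bit Linux/macOS
```

If this is the cause, one way to test the hypothesis would be to create the .t7 file and load it on the same platform, for example both on Linux.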

Thank you for your time!

Neo_li’s solution does not work for me either.