Hi everyone, with Torch we had loadcaffe, which was extremely useful. Is there any way to achieve the same functionality in PyTorch? Put simply, how can we load Caffe models in PyTorch?
Thanks a lot!
One way would be to use loadcaffe to load the weights in Lua, save them, and use torch.utils.serialization.load_lua to load them in PyTorch. Unfortunately, we don't have a ready solution for that yet.
Thanks, Adam.
Is there a plan for a solution in the future?
Not anytime soon, I'm afraid. There are a lot of higher-priority tasks.
But it should be quite simple for someone to add that: you just need to read the protobufs and translate the graphs into modules.
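A minimal sketch of that idea, assuming a deploy-style .prototxt in the newer `layer { ... }` format (not the old V1 `layers` format) and a plain linear chain of layers. It parses the protobuf text format with a tiny hand-written tokenizer and turns Convolution, ReLU and Pooling layers into torch.nn constructor strings. The names `parse_prototxt` and `to_torch_specs` are hypothetical; a real converter would parse with the protobuf library and the compiled caffe.proto schema, read the weights from the binary .caffemodel, and follow the bottom/top blob names instead of assuming a chain.

```python
import re

# Tokens: comments, quoted strings, braces/colons, or bare words (numbers, enums, field names)
TOKEN = re.compile(r'#[^\n]*|"[^"]*"|[{}:]|[^\s{}:"#]+')


def _scalar(tok):
    """Convert one value token: quoted string, int, float, or bare enum/bool name."""
    if tok.startswith('"'):
        return tok[1:-1]
    for cast in (int, float):
        try:
            return cast(tok)
        except ValueError:
            pass
    return tok  # e.g. MAX or true


def parse_prototxt(text):
    """Parse protobuf text format into nested dicts; every field maps to a list of values."""
    tokens = [t for t in TOKEN.findall(text) if not t.startswith('#')]
    pos = 0

    def parse_block():
        nonlocal pos
        fields = {}
        while pos < len(tokens) and tokens[pos] != '}':
            key = tokens[pos]
            pos += 1
            if tokens[pos] == ':':
                pos += 1  # the colon is optional before a nested message
            if tokens[pos] == '{':
                pos += 1
                value = parse_block()
                pos += 1  # skip the closing '}'
            else:
                value = _scalar(tokens[pos])
                pos += 1
            fields.setdefault(key, []).append(value)
        return fields

    return parse_block()


def to_torch_specs(net, in_channels=3):
    """Translate a linear chain of Caffe layers into torch.nn constructor strings."""
    specs, channels = [], in_channels
    for layer in net.get('layer', []):
        kind = layer['type'][0]
        if kind == 'Convolution':
            p = layer['convolution_param'][0]
            out = p['num_output'][0]
            specs.append(f"nn.Conv2d({channels}, {out}, kernel_size={p['kernel_size'][0]}, "
                         f"stride={p.get('stride', [1])[0]}, padding={p.get('pad', [0])[0]})")
            channels = out  # the next convolution consumes this many channels
        elif kind == 'ReLU':
            specs.append('nn.ReLU(inplace=True)')
        elif kind == 'Pooling':
            p = layer['pooling_param'][0]
            cls = {'MAX': 'MaxPool2d', 'AVE': 'AvgPool2d'}[p.get('pool', ['MAX'])[0]]
            # Caffe rounds pooled sizes up, so ceil_mode=True mirrors it
            specs.append(f"nn.{cls}(kernel_size={p['kernel_size'][0]}, "
                         f"stride={p.get('stride', [1])[0]}, ceil_mode=True)")
        else:
            specs.append(f'# TODO: no translation for Caffe layer type {kind!r}')
    return specs
```

The channel count is threaded through the loop because Caffe only stores `num_output`, while `nn.Conv2d` also needs the number of input channels.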
Got it. Thanks a lot for the suggestions.
@apaszke, do you mean that I can load Caffe models in PyTorch by going through Torch?
Yes, it should be possible. But these will only be torch.legacy.nn models, not torch.nn networks.
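Most legacy modules have a direct torch.nn counterpart, so one way to get a torch.nn network is to rebuild it layer by layer. A minimal sketch, assuming you start from the one-line description Torch prints for each module (as in a model printout); the function name `translate_legacy` is hypothetical, and mapping `nn.View(-1)` to `nn.Flatten()` assumes a batched input. Note that Torch prints kernel, stride and padding as (W, H), while torch.nn takes (H, W).

```python
import re


def translate_legacy(line):
    """Map one printed Lua/legacy nn module line to a torch.nn constructor string."""
    m = re.search(r'nn\.(\w+)(?:\((.*)\))?\s*$', line)
    if m is None:
        raise ValueError(f'not a printed nn module: {line!r}')
    name, args = m.group(1), m.group(2) or ''
    # Every number in the argument text, e.g. "3 -> 64, 3x3, 1,1, 1,1" -> [3, 64, 3, 3, 1, 1, 1, 1]
    nums = [float(t) if '.' in t else int(t)
            for t in re.findall(r'-?\d+(?:\.\d+)?', args)]

    if name == 'SpatialConvolution':
        # Torch prints (nIn -> nOut, kWxkH, dW,dH, padW,padH); torch.nn takes (H, W) order
        n_in, n_out, kw, kh = nums[:4]
        dw, dh = nums[4:6] or (1, 1)  # stride is omitted when it is 1 and there is no padding
        pw, ph = nums[6:8] or (0, 0)
        return (f'nn.Conv2d({n_in}, {n_out}, kernel_size=({kh}, {kw}), '
                f'stride=({dh}, {dw}), padding=({ph}, {pw}))')
    if name == 'SpatialMaxPooling':
        kw, kh, dw, dh = nums[:4]
        pw, ph = nums[4:6] or (0, 0)
        return (f'nn.MaxPool2d(kernel_size=({kh}, {kw}), '
                f'stride=({dh}, {dw}), padding=({ph}, {pw}))')
    if name == 'Linear':
        return f'nn.Linear({nums[0]}, {nums[1]})'
    if name == 'Dropout':
        return f'nn.Dropout(p={nums[0]})'
    if name == 'ReLU':
        return 'nn.ReLU()'
    if name == 'View':
        return 'nn.Flatten()'  # assumption: batched input, so keep dim 0 and flatten the rest
    return f'# TODO: no mapping for nn.{name}'
```

The number regex ignores the arrow, so it works whether the printout shows `->` or `→`. This only rebuilds the architecture; copying the weights and biases into the new modules is a separate step.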
Hello, I am not familiar with the Caffe framework. I want to convert this network definition into PyTorch: https://github.com/alexgkendall/caffe-posenet/blob/master/models/bvlc_googlenet/deploy.prototxt
I am using Arch Linux. Would that be possible?
I have finally fixed my problem.
Thanks to everyone; I have integrated it into my project.
Good evening,
Following your advice, apaszke, I downloaded loadcaffe and converted the Caffe model and prototxt file into a .t7 model file.
I am using this to take the model from Caffe to PyTorch. Specifically, I am going with the age estimation variant.
require 'loadcaffe'
require 'torch'
-- Build a Torch nn model from the Caffe network definition and weights, then save it as .t7
model = loadcaffe.load('age_train.prototxt', 'dex_chalearn_iccv2015.caffemodel', 'nn')
torch.save('imdb.t7', model)
The model which gets created is:
nn.Sequential {
[input → (1) → (2) → (3) → (4) → (5) → (6) → (7) → (8) → (9) → (10) → (11) → (12) → (13) → (14) → (15) → (16) → (17) → (18) → (19) → (20) → (21) → (22) → (23) → (24) → (25) → (26) → (27) → (28) → (29) → (30) → (31) → (32) → (33) → (34) → (35) → (36) → (37) → (38) → (39) → output]
(1): nn.SpatialConvolution(3 → 64, 3x3, 1,1, 1,1)
(2): nn.ReLU
(3): nn.SpatialConvolution(64 → 64, 3x3, 1,1, 1,1)
(4): nn.ReLU
(5): nn.SpatialMaxPooling(2x2, 2,2)
(6): nn.SpatialConvolution(64 → 128, 3x3, 1,1, 1,1)
(7): nn.ReLU
(8): nn.SpatialConvolution(128 → 128, 3x3, 1,1, 1,1)
(9): nn.ReLU
(10): nn.SpatialMaxPooling(2x2, 2,2)
(11): nn.SpatialConvolution(128 → 256, 3x3, 1,1, 1,1)
(12): nn.ReLU
(13): nn.SpatialConvolution(256 → 256, 3x3, 1,1, 1,1)
(14): nn.ReLU
(15): nn.SpatialConvolution(256 → 256, 3x3, 1,1, 1,1)
(16): nn.ReLU
(17): nn.SpatialMaxPooling(2x2, 2,2)
(18): nn.SpatialConvolution(256 → 512, 3x3, 1,1, 1,1)
(19): nn.ReLU
(20): nn.SpatialConvolution(512 → 512, 3x3, 1,1, 1,1)
(21): nn.ReLU
(22): nn.SpatialConvolution(512 → 512, 3x3, 1,1, 1,1)
(23): nn.ReLU
(24): nn.SpatialMaxPooling(2x2, 2,2)
(25): nn.SpatialConvolution(512 → 512, 3x3, 1,1, 1,1)
(26): nn.ReLU
(27): nn.SpatialConvolution(512 → 512, 3x3, 1,1, 1,1)
(28): nn.ReLU
(29): nn.SpatialConvolution(512 → 512, 3x3, 1,1, 1,1)
(30): nn.ReLU
(31): nn.SpatialMaxPooling(2x2, 2,2)
(32): nn.View(-1)
(33): nn.Linear(25088 → 4096)
(34): nn.ReLU
(35): nn.Dropout(0.500000)
(36): nn.Linear(4096 → 4096)
(37): nn.ReLU
(38): nn.Dropout(0.500000)
(39): nn.Linear(4096 → 101)
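The 25088 inputs of layer (33) follow from the VGG-16 layout printed above. A quick check, assuming a 224x224 RGB input (the usual VGG-16 input size; the input size is not stated in this thread): the 3x3 convolutions with padding 1 keep the spatial size, and each of the five 2x2 stride-2 poolings halves it, so 224 shrinks to 7 and 512 x 7 x 7 = 25088. Whatever replaces `nn.View(-1)` in a torch.nn rebuild has to produce exactly this many features per image.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output side length of a convolution or pooling window (floor rounding)."""
    return (size + 2 * pad - kernel) // stride + 1


def vgg16_flat_features(input_size=224):
    """Features reaching the first Linear layer of the VGG-16 stack printed above."""
    size, channels = input_size, 3
    # (number of 3x3 convolutions, output channels) for each of the five stages
    for n_convs, out_channels in [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]:
        for _ in range(n_convs):
            size = conv_out(size, kernel=3, stride=1, pad=1)  # 3x3, pad 1: size unchanged
        channels = out_channels
        size = conv_out(size, kernel=2, stride=2)  # 2x2 max pooling, stride 2: size halved
    return channels * size * size


print(vgg16_flat_features(224))  # 25088 (= 512 * 7 * 7)
```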
When I try to load it with torch.utils.serialization.load_lua('imdb.t7'), I get an error saying the .t7 file may be corrupted.
I was hoping to get some guidance.
Here is the traceback:
Traceback (most recent call last):
File "", line 1, in
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 608, in load_lua
return reader.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 593, in read
return self.read_object()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 546, in read_object
return reader_registry[cls_name](self, version)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 318, in wrapper
obj = build_fn(reader, version)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 243, in read_nn_class
attributes = reader.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 571, in read_table
k = self.read()
File "D:\miniCondPy\envs\221\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 598, in read
"corrupted.".format(typeidx))
torch.utils.serialization.read_lua_file.T7ReaderException: unknown type id -1156370551. The file may be corrupted.
Thank you for your time!
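The exception is raised while reading a type id, the 4-byte integer that Torch7's binary format writes in front of every serialized value. Valid ids run from 0 to 8, so -1156370551 most likely means the reader is no longer aligned with the data. A minimal diagnostic sketch, assuming the default binary format written by `torch.save` on a little-endian machine (the helper name `read_t7_header` is hypothetical): it checks whether the file at least starts with a well-formed Torch object header. Since the traceback shows the failure deep inside nested tables, a clean header would point at some later value being decoded with the wrong size, rather than at a file that is broken from the start.

```python
import io
import struct

# Type ids used by Torch7's File:writeObject
T7_TYPES = {0: 'nil', 1: 'number', 2: 'string', 3: 'table', 4: 'torch',
            5: 'boolean', 6: 'function', 7: 'legacy_recur_function', 8: 'recur_function'}


def _read_int(f):
    """Read one little-endian 32-bit signed int."""
    (value,) = struct.unpack('<i', f.read(4))
    return value


def read_t7_header(f):
    """Decode the header of the first object in a binary .t7 stream."""
    typeidx = _read_int(f)
    if typeidx not in T7_TYPES:
        raise ValueError(f'unknown type id {typeidx}: not a binary t7 stream, or misaligned')
    header = {'type': T7_TYPES[typeidx]}
    if typeidx == 4:  # torch object: memo index, then version string, then class name
        header['index'] = _read_int(f)
        header['version'] = f.read(_read_int(f)).decode('ascii')
        header['class'] = f.read(_read_int(f)).decode('ascii')
    return header
```

For example, pass `open('imdb.t7', 'rb')` to `read_t7_header`; for a model saved as in the post above it should report the class `nn.Sequential`. One possibility worth checking, not confirmed in this thread: Torch7 stores tensor sizes and strides as C `long` values, whose width differs between platforms (`struct.calcsize('l')` is typically 8 on 64-bit Linux and 4 on Windows), so a file written on one platform may be misread on another.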
Neo_li's solution does not work for me either.