Convert/import Torch model to PyTorch

Hi,

Great library! I’d like to ask if it is possible to import a trained Torch model to PyTorch…

Thanks


As of today, you can deserialize Lua .t7 files into PyTorch. Supported contents are tensors, numbers, tables, nn models (but not nngraph models) and strings.

We provide the load_lua utility (torch.utils.serialization.load_lua) for this purpose.

Here’s an example of saving a tensor in Torch and loading it back in PyTorch:


th> a = torch.randn(10)
                                                                      [0.0027s]
th> torch.save('a.t7', a)
                                                                      [0.0010s]
th> a
-1.4479
 1.3707
 0.5663
-1.0590
 0.0706
-1.6495
-1.0805
 0.8277
-0.4595
 0.1237
[torch.DoubleTensor of size 10]

                                                                      [0.0033s]
In [1]: import torch

In [2]: from torch.utils.serialization import load_lua

In [3]: a = load_lua('a.t7')

In [4]: a
Out[4]:

-1.4479
 1.3707
 0.5663
-1.0590
 0.0706
-1.6495
-1.0805
 0.8277
-0.4595
 0.1237
[torch.DoubleTensor of size 10]

Here’s an example of loading a two-layer sequential neural network:

th> a = nn.Sequential():add(nn.Linear(10, 20)):add(nn.ReLU())
                                                                      [0.0001s]
th> a
nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.Linear(10 -> 20)
  (2): nn.ReLU
}
                                                                      [0.0001s]
th> torch.save('a.t7', a)
                                                                      [0.0008s]
th>
In [5]: a = load_lua('a.t7')

In [6]: a
Out[6]:
nn.Sequential {
  [input -> (0) -> (1) -> output]
  (0): nn.Linear(10 -> 20)
  (1): nn.ReLU
}

In [7]: a.__class__
Out[7]: torch.legacy.nn.Sequential.Sequential

Hi,

Is there a simple way to convert a torch.legacy.nn module into a torch.nn module?


No, unfortunately we don’t have an automatic converter at the moment.


But it should be quite simple to add.
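A minimal sketch of what such a converter could look like, assuming the relevant fields have already been read off the loaded legacy modules: a table maps each legacy type name to a function that emits equivalent torch.nn source code. The descriptor dicts, the `weight_shape` key and all function names are hypothetical (only the field names such as nInputPlane, kW, dH, padW mirror Torch's own). A real converter would also have to copy the weight and bias tensors, which needs PyTorch and is left out here.

```python
# Hypothetical sketch: emit torch.nn source code from plain-dict descriptions
# of legacy modules. These dicts are invented for illustration; they are not
# the objects returned by load_lua.

def _conv(m):
    # Torch orders kernel/stride/padding width-first; torch.nn uses (height, width).
    return ("nn.Conv2d({nInputPlane}, {nOutputPlane}, kernel_size=({kH}, {kW}), "
            "stride=({dH}, {dW}), padding=({padH}, {padW}))").format(**m)


def _maxpool(m):
    # Assumes floor mode (Torch's default); a pooling layer switched to ceil
    # mode in Torch would also need ceil_mode=True.
    return ("nn.MaxPool2d(kernel_size=({kH}, {kW}), stride=({dH}, {dW}), "
            "padding=({padH}, {padW}))").format(**m)


def _linear(m):
    # A Torch Linear weight is (outputSize, inputSize), like torch.nn.Linear's.
    out_features, in_features = m["weight_shape"]
    return "nn.Linear({}, {})".format(in_features, out_features)


CONVERTERS = {
    "SpatialConvolution": _conv,
    "SpatialMaxPooling": _maxpool,
    "Linear": _linear,
    "ReLU": lambda m: "nn.ReLU()",
}


def to_pytorch_source(modules):
    """Return Python source for an nn.Sequential equivalent to `modules`."""
    lines = []
    for m in modules:
        kind = m["type"]
        if kind not in CONVERTERS:
            raise NotImplementedError("no converter for legacy module " + kind)
        lines.append("    " + CONVERTERS[kind](m) + ",")
    return "nn.Sequential(\n" + "\n".join(lines) + "\n)"


# The first three layers of the AlexNet printed later in this thread.
layers = [
    {"type": "SpatialConvolution", "nInputPlane": 3, "nOutputPlane": 64,
     "kW": 11, "kH": 11, "dW": 4, "dH": 4, "padW": 2, "padH": 2},
    {"type": "ReLU"},
    {"type": "SpatialMaxPooling", "kW": 3, "kH": 3, "dW": 2, "dH": 2,
     "padW": 0, "padH": 0},
]
print(to_pytorch_source(layers))
```

Unknown module types raise instead of being silently skipped, so a missing mapping shows up immediately rather than as a wrong network.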


Hi guys, I am trying to load a Torch7 model for use in PyTorch, namely this one: https://github.com/e-lab/imagenet-multiGPU.torch/blob/master/models/enet128.lua
but it does not work:

>>> n=load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/model.net')
>>> n.forward(torch.FloatTensor(1,3,128,128))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Module.py", line 32, in forward
    return self.updateOutput(input)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Sequential.py", line 35, in updateOutput
    currentOutput = module.updateOutput(currentOutput)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/ConcatTable.py", line 12, in updateOutput
    self.output = [module.updateOutput(input) for module in self.modules]
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/ConcatTable.py", line 12, in <listcomp>
    self.output = [module.updateOutput(input) for module in self.modules]
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/SpatialMaxPooling.py", line 33, in updateOutput
    if self.indices is None:
AttributeError: 'SpatialMaxPooling' object has no attribute 'indices'

The same happens with an AlexNet:

>>> n=load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/model.net')
>>> n
nn.Sequential {
  [input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> output]
  (0): nn.SpatialConvolution(3 -> 64, 11x11, 4, 4, 2, 2)
  (1): nn.ReLU
  (2): nn.SpatialMaxPooling(3x3, 2, 2)
  (3): nn.SpatialConvolution(64 -> 192, 5x5, 1, 1, 2, 2)
  (4): nn.ReLU
  (5): nn.SpatialMaxPooling(3x3, 2, 2)
  (6): nn.SpatialConvolution(192 -> 384, 3x3, 1, 1, 1, 1)
  (7): nn.ReLU
  (8): nn.SpatialConvolution(384 -> 256, 3x3, 1, 1, 1, 1)
  (9): nn.ReLU
  (10): nn.SpatialConvolution(256 -> 256, 3x3, 1, 1, 1, 1)
  (11): nn.ReLU
  (12): nn.SpatialMaxPooling(3x3, 2, 2)
  (13): nn.View(9216)
  (14): nn.Linear(9216 -> 4096)
  (15): nn.ReLU
  (16): nn.Linear(4096 -> 4096)
  (17): nn.ReLU
  (18): nn.Linear(4096 -> 46)
  (19): nn.SoftMax
}
>>> n.forward(torch.FloatTensor(1,3,224,224))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Module.py", line 32, in forward
    return self.updateOutput(input)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Sequential.py", line 35, in updateOutput
    currentOutput = module.updateOutput(currentOutput)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Linear.py", line 43, in updateOutput
    assert input.dim() == 2
AssertionError
>>> n.forward(torch.FloatTensor(3,224,224))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Module.py", line 32, in forward
    return self.updateOutput(input)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Sequential.py", line 35, in updateOutput
    currentOutput = module.updateOutput(currentOutput)
  File "/usr/local/lib/python3.6/site-packages/torch/legacy/nn/Linear.py", line 43, in updateOutput
    assert input.dim() == 2
AssertionError

Can you explain how to run old networks like these in PyTorch, or how to go about doing that?

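For AlexNet, just replace the View layer so that it keeps a batch dimension:

from torch.legacy import nn
n.modules[13] = nn.View(1, 9216)  # instead of nn.View(9216)

The legacy Linear module only accepts 2D (batch × features) input, which is what the assert input.dim() == 2 in both tracebacks checks.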
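As a sanity check on the 9216 in the AlexNet fix, the standard output-size formula floor((in + 2*pad - kernel) / stride) + 1, applied to the layers printed above, gives a 256 × 6 × 6 feature map for a 224 × 224 input. The sketch assumes floor-mode pooling, which is Torch's default. With a batch of one the features arrive as (1, 256, 6, 6): View(9216) turns that into a 1D tensor, exactly what Linear's assert rejects, while View(1, 9216) keeps the batch axis. The unbatched (3, 224, 224) attempt fails at the same assert for the same reason.

```python
# Pure-Python check of the AlexNet feature size; no Torch needed.

def out_size(size, kernel, stride, pad=0):
    # Floor-mode output size of a convolution or pooling layer (Torch default).
    return (size + 2 * pad - kernel) // stride + 1


# (kernel, stride, pad) for each spatial layer of the printed model.
spatial_layers = [
    (11, 4, 2),  # (0)  SpatialConvolution 3 -> 64
    (3, 2, 0),   # (2)  SpatialMaxPooling
    (5, 1, 2),   # (3)  SpatialConvolution 64 -> 192
    (3, 2, 0),   # (5)  SpatialMaxPooling
    (3, 1, 1),   # (6)  SpatialConvolution 192 -> 384
    (3, 1, 1),   # (8)  SpatialConvolution 384 -> 256
    (3, 1, 1),   # (10) SpatialConvolution 256 -> 256
    (3, 2, 0),   # (12) SpatialMaxPooling
]

size = 224
for kernel, stride, pad in spatial_layers:
    size = out_size(size, kernel, stride, pad)

channels = 256
features = channels * size * size
print(size, features)  # 6 9216
```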

For ENet it’s trickier, because there are several bugs in the PyTorch legacy code.

First, the error you get ('SpatialMaxPooling' object has no attribute 'indices') is caused by a bug in PyTorch’s SpatialMaxPooling.py. In line 34, instead of
if self.indices is None:
it should be
if not hasattr(self, 'indices'):
This probably comes from an incorrect Lua-to-Python conversion of the code.
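A small pure-Python illustration of this failure mode, assuming (not verified against read_lua_file.py) that the .t7 reader restores a module by filling in its saved Lua fields without running the Python constructor, so attributes that only `__init__` creates never exist. `PoolingLike` is a hypothetical stand-in, not the real legacy class. It also shows a slightly more defensive variant of the suggested check, `getattr(self, 'indices', None) is None`, which treats a missing attribute and one still set to None the same way.

```python
# Hypothetical stand-in: why `self.indices is None` can raise AttributeError
# on an object that was restored from saved fields (assumed loader behaviour).

class PoolingLike:
    """Illustrative class, not torch.legacy.nn.SpatialMaxPooling itself."""

    def __init__(self, kW, kH):
        self.kW, self.kH = kW, kH
        self.indices = None  # exists only if the constructor ran

    def indices_missing(self):
        # Covers both cases: attribute absent, or present but still None.
        return getattr(self, "indices", None) is None


fresh = PoolingLike(3, 3)

# Simulate deserialization: allocate without __init__, then copy saved fields.
restored = object.__new__(PoolingLike)
restored.__dict__.update({"kW": 3, "kH": 3})

print(fresh.indices_missing(), restored.indices_missing())  # True True

try:
    restored.indices is None  # the original check from SpatialMaxPooling.py
except AttributeError as err:
    print("original check fails:", err)
```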

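Then, in JoinTable.py, the line
dimension = self.dimension
should become
dimension = self.dimension - 1
because the dimension stored in the Lua model is 1-based, while Python indexing is 0-based.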

Then, Padding.py does not handle nInputDim at all. As a temporary fix, change self.dim to self.dim+1 in lines 19 and 21 of Padding.py.
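To make the index bookkeeping behind the JoinTable and Padding fixes concrete, here is a sketch of how Lua's nn resolves the working dimension of these two modules, as assumed here from the Lua sources (positive dims only for Padding), followed by the shift from a 1-based Lua dim to a 0-based Python axis. The function names are hypothetical. For example, a module built for 3D (C, H, W) input with dim 1 and nInputDims=3, fed a batched 4D input, resolves to axis 1 (channels).

```python
# Hypothetical helpers. The Lua-side rules are an assumption based on Lua nn's
# JoinTable and Padding; the final `- 1` converts Lua's 1-based dimension into
# the 0-based axis a Python implementation should use.

def jointable_axis(lua_dim, input_ndim, n_input_dims=None):
    """0-based axis that a Lua nn.JoinTable(dimension, nInputDims) joins along."""
    if lua_dim < 0:
        dim = input_ndim + lua_dim + 1      # negative dims count from the end
    elif n_input_dims is not None and input_ndim == n_input_dims + 1:
        dim = lua_dim + 1                   # one extra (batch) dim: shift right
    else:
        dim = lua_dim
    return dim - 1                          # 1-based Lua -> 0-based Python


def padding_axis(lua_dim, input_ndim, n_input_dim=None):
    """0-based axis that a Lua nn.Padding(dim, pad, nInputDim) pads (dim > 0)."""
    dim = lua_dim
    if n_input_dim is not None and input_ndim != n_input_dim:
        dim += 1                            # input has a leading batch dim
    return dim - 1


# Channel dimension of a 3D (C, H, W) model fed a batched (N, C, H, W) input:
print(jointable_axis(1, 4, n_input_dims=3), padding_axis(1, 4, n_input_dim=3))  # 1 1
```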

Finally, the same View issue appears at the end of the network.

After all these changes, it will run.


@mvitez if you have a patch for these errors, could you please send a PR?

Sure, I will. This forum wants 20 characters to answer.


Modify your Torch according to this PR

and process your model.net with this script:

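require 'nn'

-- Recursively patch a network so the PyTorch legacy modules can run it
local function patch(m)
   -- Padding built for 3D (unbatched) input: account for the batch dimension
   if torch.type(m) == 'nn.Padding' and m.nInputDim == 3 then
      m.dim = m.dim + 1
      m.nInputDim = 4
   end
   -- View(n) -> View(1, n), so the following Linear receives 2D input
   if torch.type(m) == 'nn.View' and #m.size == 1 then
      local newsize = torch.LongStorage(2)
      newsize[1] = 1
      newsize[2] = m.size[1]
      m.size = newsize
   end
   -- Recurse into containers (Sequential, ConcatTable, ...)
   if m.modules then
      for i = 1, #m.modules do
         patch(m.modules[i])
      end
   end
end

local net = torch.load('model.net')
patch(net)
torch.save('model2.net', net)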

Does load_lua work for models with cudnn layers? I am getting serialization errors in read_lua_file.py, so I had to run cudnn.convert(model, nn) first and then use load_lua.


No, it doesn’t support cudnn layers at the moment.


convert_torch_to_pytorch
This script converts a .t7 model file into Python source code and a PyTorch .pth model.
It supports standard nn and cudnn layers, and has been validated on AlexNet, VGG, ResNet and ResNeXt.


When using load_lua to load a model pre-trained in Torch7, I get the error KeyError: 'torch.CudaTensor'. Any thoughts on how to overcome this?

I think serialization of GPU models might not be implemented. A quick workaround is to load the checkpoint in Lua, cast it to CPU float, and then try loading it in PyTorch.
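Before retrying load_lua after converting the model in Lua, a quick heuristic can tell whether a .t7 file still contains GPU objects. It relies on the assumption that Torch's serializer writes each object's class name (such as torch.CudaTensor or cudnn.SpatialConvolution) as a plain string, so a byte search for those prefixes is enough. It can give false positives if some string field happens to contain the same text. `gpu_markers_in_t7` is a hypothetical helper, not part of PyTorch.

```python
# Heuristic pre-check (assumption: .t7 files store class names as plain
# strings). Not a load_lua feature; it only scans the raw bytes.
import os
import tempfile

GPU_MARKERS = (b"torch.Cuda", b"cudnn.")


def gpu_markers_in_t7(path, chunk_size=1 << 20):
    """Return the set of GPU-related class-name prefixes found in a .t7 file."""
    found = set()
    overlap = max(len(m) for m in GPU_MARKERS) - 1
    tail = b""
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            # Prepend the last few bytes read so far, so a marker split
            # across chunk boundaries is still found.
            window = tail + chunk
            for marker in GPU_MARKERS:
                if marker in window:
                    found.add(marker.decode())
            tail = window[-overlap:]
    return found


# Usage on toy bytes (not a real .t7 file):
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "toy.t7")
    with open(path, "wb") as f:
        f.write(b"\x00\x01torch.CudaTensor\x00")
    print(gpu_markers_in_t7(path))  # {'torch.Cuda'}
```

An empty set suggests the cudnn.convert(model, nn) plus model:float() step worked; any hit points at objects that still need converting on the Lua side.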

When using load_lua, I got AttributeError: type object 'FloatStorage' has no attribute 'from_buffer'. How can I solve this problem?


Is there any chance your model contains CUDA tensors?

I got the same problem. Any solutions?

Yes, do you have any CUDA tensors in your model?

Thanks, that solves my problem. Previously I used cudnn.convert(model, nn) but not model:float(), so the model still contained CUDA tensors.
