I have a problem loading my trained models. Therefore, I created a simple example to find out what is wrong with my save and load code.
Here you can see the file where I save my model:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class MeinNetz(nn.Module):
    """Small fully connected network: Linear(10, 10) -> ReLU -> Linear(10, 10).

    The submodule attribute names ``lin1`` and ``lin2`` become the prefixes of
    the ``state_dict`` keys (``lin1.weight``, ``lin1.bias``, ``lin2.weight``,
    ``lin2.bias``), so they must be identical in the saving and loading code.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters for reproducible random init under a fixed seed.
        self.lin1 = nn.Linear(10, 10)
        self.lin2 = nn.Linear(10, 10)

    def forward(self, x):
        """Return lin2(relu(lin1(x))); the last dimension of ``x`` must be 10."""
        return self.lin2(F.relu(self.lin1(x)))
# Build a (freshly initialised) network, save its parameters to ./net.pth,
# then print its output on a fixed all-ones input for later comparison.
netz = MeinNetz()
# torch.autograd.Variable is deprecated (PyTorch >= 0.4): plain tensors take
# part in autograd directly. The name `x` avoids shadowing the builtin `input`.
x = torch.ones(1, 10)
torch.save(netz.state_dict(), './net.pth')
output = netz(x)
print(output)
And here is the file where I want to load my model:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class MeinNetz(nn.Module):
    """Small fully connected network: Linear(10, 10) -> ReLU -> Linear(10, 10).

    Must match the class used when saving: the attribute names ``lin1`` and
    ``lin2`` determine the ``state_dict`` keys that ``load_state_dict`` expects.
    """

    def __init__(self):
        super().__init__()
        # Same layer names and order as in the saving script.
        self.lin1 = nn.Linear(10, 10)
        self.lin2 = nn.Linear(10, 10)

    def forward(self, x):
        """Return lin2(relu(lin1(x))); the last dimension of ``x`` must be 10."""
        return self.lin2(F.relu(self.lin1(x)))
# Instantiate the model BEFORE loading: load_state_dict is an instance method.
# Writing `netz2 = MeinNetz` (no parentheses) binds the class itself, so the
# state dict gets passed as `self` and Python raises
# "load_state_dict() missing 1 required positional argument: 'state_dict'".
netz2 = MeinNetz()
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
netz2.load_state_dict(torch.load('./net.pth'))
# torch.autograd.Variable is deprecated (PyTorch >= 0.4): plain tensors suffice.
# The name `x` avoids shadowing the builtin `input`.
x = torch.ones(1, 10)
output = netz2(x)
print(output)
When I run the load file, I get the following error:
load_state_dict() missing 1 required positional argument: 'state_dict'
What is the problem here?
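Update: I had forgotten the parentheses when creating the model (`netz2 = MeinNetz` instead of `netz2 = MeinNetz()`). As a result, `load_state_dict` was called on the class rather than on an instance, and the loaded dictionary was bound to `self`. With `netz2 = MeinNetz()` the simple example works.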