Hi there! I was wondering if someone could help me:
I have a pretrained linear encoder that I would like to add in front of my real model, but I don't know how to do it.
Let me explain a little bit better:
- I have trained an encoder NN on my dataset and saved its parameters with torch.save(model.state_dict(), 'guada_withvalid_DNN.pt').
- Then, in another .py file, I have the new scenario with a model defined like this:
import torch
import torch.nn as nn

# neural network architecture definition
# (sequence_length and device are assumed to be globals defined elsewhere in the script)
class WifiRNN(nn.Module):
    def __init__(self, i_size, h_size, n_layers, num_classes):
        super(WifiRNN, self).__init__()
        self.input_size = i_size
        self.hidden_size = h_size
        self.num_layers = n_layers
        self.num_classes = num_classes
        self.wifi_rnn = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)
        self.out = nn.Linear(in_features=self.hidden_size * sequence_length, out_features=self.num_classes)

    def forward(self, x_in, h_state):
        # r_out: (batch, seq_len, hidden_size) because batch_first=True
        r_out, h_state = self.wifi_rnn(x_in, h_state)
        # flatten all time steps into one feature vector per sample
        r_out = r_out.reshape(r_out.shape[0], -1)
        out = self.out(r_out)
        return out, h_state

    def init_hidden_state(self, b_size):
        # initial hidden state: (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, b_size, self.hidden_size).to(device)
        return h0
- I want to import the pretrained encoder into my new scenario and place it before the RNN: the data first goes through the pretrained encoder, whose parameters stay frozen, and its output is then fed to the RNN, which is trained with backpropagation.
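The setup described above can be sketched as a wrapper module that runs the encoder first and feeds its output to the RNN. This is a minimal sketch under assumptions: `Encoder` is a hypothetical one-layer stand-in for the real encoder class (the real one must match the saved state dict), the encoder is applied to every time step of an input shaped (batch, seq_len, features), and `EncoderRNN` and its arguments are illustrative names. Freezing uses `requires_grad_(False)`, so backpropagation only produces gradients for the RNN and the output layer.

```python
import torch
import torch.nn as nn


# Hypothetical encoder: replace with the exact class used for pre-training,
# otherwise load_state_dict() will report missing/unexpected keys.
class Encoder(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


class EncoderRNN(nn.Module):
    """Frozen pretrained encoder followed by a trainable RNN classifier."""

    def __init__(self, encoder, enc_out, h_size, n_layers, num_classes, seq_len):
        super().__init__()
        self.encoder = encoder
        self.encoder.requires_grad_(False)  # no gradients for encoder weights
        self.encoder.eval()                 # fixed behaviour for dropout/batchnorm
        self.num_layers = n_layers
        self.hidden_size = h_size
        self.rnn = nn.RNN(input_size=enc_out, hidden_size=h_size,
                          num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(h_size * seq_len, num_classes)

    def train(self, mode=True):
        # model.train() would otherwise switch the encoder back to training mode.
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x_in, h_state):
        # x_in: (batch, seq_len, in_features); nn.Linear acts on the last dim,
        # so the encoder is applied to every time step independently.
        z = self.encoder(x_in)
        r_out, h_state = self.rnn(z, h_state)
        out = self.out(r_out.reshape(r_out.shape[0], -1))
        return out, h_state

    def init_hidden_state(self, b_size):
        # Create h0 on the same device as the model's parameters.
        device = next(self.parameters()).device
        return torch.zeros(self.num_layers, b_size, self.hidden_size, device=device)


# Usage with illustrative sizes.
encoder = Encoder(in_features=8, out_features=4)
# encoder.load_state_dict(torch.load('guada_withvalid_DNN.pt'))
model = EncoderRNN(encoder, enc_out=4, h_size=16, n_layers=2, num_classes=3, seq_len=5)
x = torch.randn(2, 5, 8)
out, h = model(x, model.init_hidden_state(2))
```

In the real script, the commented-out line would load the saved weights into `encoder` before it is wrapped. Keeping the encoder in `eval()` mode only matters if it contains dropout or batch normalization; for a purely linear encoder it is harmless.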
This is what I've tried:
model_encoder = nn.Module
# model_encoder = torch.load('guada_withvalid_DNN.pt')
model_encoder.load_state_dict(torch.load('guada_withvalid_DNN.pt'))
However, I get:
Traceback (most recent call last):
File "/snap/pycharm-community/232/plugins/python-ce/helpers/pydev/pydevd.py", line 1477, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/snap/pycharm-community/232/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/lauram/PycharmProjects/RNN/RNN_preEncoder.py", line 146, in <module>
model_encoder.load_state_dict(torch.load('guada_withvalid_DNN'))
TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'
python-BaseException
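The traceback comes from `nn.Module` being used without parentheses: `model_encoder` then refers to the class itself, not an instance, so in `model_encoder.load_state_dict(...)` the loaded dictionary is bound to `self` and the real `state_dict` argument is missing. `nn.Module()` would not help either, because the bare base class has no layers to receive the weights. A state dict only holds tensors, so the encoder class has to be defined (or imported) in the new script and instantiated before loading. A minimal sketch, where the `Encoder` class and its layer sizes are hypothetical stand-ins for the real pretrained encoder and a temporary file replaces the real checkpoint path:

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for the encoder class used during pre-training.
class Encoder(nn.Module):
    def __init__(self, in_features=8, out_features=4):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


# Reproduce the error: without parentheses nn.Module is the class, so the
# dictionary is bound to `self` and `state_dict` is reported as missing.
error_message = ""
try:
    nn.Module.load_state_dict({})
except TypeError as err:
    error_message = str(err)

# Simulate the training script: save only the weights (a dict of tensors).
path = os.path.join(tempfile.mkdtemp(), "guada_withvalid_DNN.pt")
trained = Encoder()
torch.save(trained.state_dict(), path)

# New script: build an *instance* of the same architecture, then load into it.
model_encoder = Encoder()
model_encoder.load_state_dict(torch.load(path))
model_encoder.eval()
```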
Can anybody help me? The only examples I have found use pretrained models from torchvision, which is not what I am trying to do.
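The freezing step in the torchvision fine-tuning examples (setting `requires_grad = False` on the pretrained parameters) is not specific to torchvision: it works on any `nn.Module`, including a custom encoder. A minimal sketch with hypothetical `nn.Linear` stand-ins for the encoder and the rest of the model, checking that one optimizer step leaves the frozen weights untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the pretrained encoder and the trainable part.
encoder = nn.Linear(8, 4)
head = nn.Linear(4, 3)
model = nn.Sequential(encoder, head)

# Freeze the encoder the same way the torchvision examples do.
for p in encoder.parameters():
    p.requires_grad = False

# Hand the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
loss_fn = nn.CrossEntropyLoss()

enc_before = encoder.weight.clone()
head_before = head.weight.clone()

x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

# One ordinary training step.
optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()

encoder_unchanged = torch.equal(encoder.weight, enc_before)
head_changed = not torch.equal(head.weight, head_before)
```

Filtering the parameter list is a common convention rather than a strict requirement, since PyTorch optimizers skip parameters whose `.grad` is `None`; it does make the intent explicit.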
Thanks!