Hi there! I was wondering if someone could help me.
I have a pretrained linear encoder that I would like to add before my real model, but I don't know how to do it.
Let me explain in a bit more detail:
- I have trained an encoder NN on my dataset and saved its parameters with `torch.save(model.state_dict(), 'guada_withvalid_DNN.pt')`.
- Then, in another `.py` file, I have the new scenario with a model defined like this:
```python
import torch
import torch.nn as nn


# neural network architecture definition
class WifiRNN(nn.Module):
    def __init__(self, i_size, h_size, n_layers, num_classes):
        super(WifiRNN, self).__init__()
        self.input_size = i_size
        self.hidden_size = h_size
        self.num_layers = n_layers
        self.num_classes = num_classes
        self.wifi_rnn = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size,
                               num_layers=self.num_layers, batch_first=True)
        # sequence_length is a global defined elsewhere in the script
        self.out = nn.Linear(in_features=self.hidden_size * sequence_length,
                             out_features=self.num_classes)

    def forward(self, x_in, h_state):
        r_out, h_state = self.wifi_rnn(x_in, h_state)
        # flatten to (batch, seq_len * hidden_size); reshape needs the batch size, not the whole shape
        r_out = r_out.reshape(r_out.shape[0], -1)
        out = self.out(r_out)
        return out, h_state

    def init_hidden_state(self, b_size):
        # create h0 on the same device as the model's parameters
        h0 = torch.zeros(self.num_layers, b_size, self.hidden_size,
                         device=next(self.parameters()).device)
        return h0
```
- I want to import the pretrained encoder into my new scenario and place it before the RNN: the data is first fed through the pretrained encoder, whose parameters stay frozen, and its output is then fed to the RNN, which is trained with backpropagation.
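A minimal sketch of the setup described above, assuming the saved encoder is a single `nn.Linear` layer wrapped in a class. `Encoder`, `EncoderRNN`, the layer sizes and the `encoder.pt` file name are hypothetical placeholders; the real encoder class would have to match the one used when `guada_withvalid_DNN.pt` was saved, otherwise `load_state_dict` reports missing or unexpected keys. The RNN part is a simplified stand-in for `WifiRNN`.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    # Hypothetical encoder: layer names and sizes must match the saved state_dict.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


class EncoderRNN(nn.Module):
    # Frozen pretrained encoder followed by a trainable RNN and linear classifier.
    def __init__(self, encoder, enc_out, h_size, n_layers, seq_len, num_classes):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False  # freeze the pretrained weights
        self.rnn = nn.RNN(input_size=enc_out, hidden_size=h_size,
                          num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(h_size * seq_len, num_classes)

    def forward(self, x, h_state):
        z = self.encoder(x)  # nn.Linear acts on the last dim of (batch, seq, features)
        r_out, h_state = self.rnn(z, h_state)
        r_out = r_out.reshape(r_out.shape[0], -1)  # (batch, seq_len * hidden)
        return self.out(r_out), h_state


path = os.path.join(tempfile.gettempdir(), "encoder.pt")  # hypothetical file name

# Stand-in for the first script: save only the encoder's weights.
torch.save(Encoder(8, 4).state_dict(), path)

# Second script: build an instance first, then load the saved weights into it.
encoder = Encoder(8, 4)
encoder.load_state_dict(torch.load(path))

model = EncoderRNN(encoder, enc_out=4, h_size=16, n_layers=1, seq_len=5, num_classes=3)
# Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)

x = torch.randn(2, 5, 8)    # (batch, seq_len, features)
h0 = torch.zeros(1, 2, 16)  # (num_layers, batch, hidden)
y = torch.tensor([0, 2])

before = model.encoder.fc.weight.clone()
logits, _ = model(x, h0)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(logits.shape)                                  # torch.Size([2, 3])
print(torch.equal(before, model.encoder.fc.weight))  # True: encoder weights unchanged
```

Setting `requires_grad = False` keeps the encoder out of the gradient computation, and filtering the optimizer's parameter list is an extra safeguard. If the real encoder contains dropout or batch norm, `model.encoder.eval()` would probably also be needed after each `model.train()` call.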
I've been trying the following:
```python
model_encoder = nn.Module
# model_encoder = torch.load('guada_withvalid_DNN.pt')
model_encoder.load_state_dict(torch.load('guada_withvalid_DNN.pt'))
```
However, I get:
```
Traceback (most recent call last):
  File "/snap/pycharm-community/232/plugins/python-ce/helpers/pydev/pydevd.py", line 1477, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/snap/pycharm-community/232/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/lauram/PycharmProjects/RNN/RNN_preEncoder.py", line 146, in <module>
    model_encoder.load_state_dict(torch.load('guada_withvalid_DNN'))
TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'
python-BaseException
```
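The error seems consistent with `nn.Module` (without parentheses) being the class rather than an instance: calling `load_state_dict` on the class makes Python bind the loaded dict to `self`, so the real `state_dict` argument is reported as missing. A minimal reproduction, using a small `nn.Linear` purely as a hypothetical stand-in for the encoder:

```python
import torch.nn as nn

# Weights from some module, standing in for torch.load('guada_withvalid_DNN.pt').
state = nn.Linear(3, 2).state_dict()

# nn.Module is the class, not an object: `state` gets bound to `self`,
# so Python complains that `state_dict` is missing.
model_encoder = nn.Module
try:
    model_encoder.load_state_dict(state)
    message = ""
except TypeError as err:
    message = str(err)
print(message)

# An instance of a class with the same layers accepts the weights.
model_encoder = nn.Linear(3, 2)
result = model_encoder.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)
```

The exact wording of the `TypeError` may vary slightly between Python versions (newer ones prefix the method with its class name).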
Can anybody help me? The only examples I have found use pretrained models from torchvision, and that's not what I am trying to do.