Hello together,
i have a ReLU-NN class as follows
20 class ReLU_NN(tr.nn.Module):
21 '''
22 Class for a ReLU-NN with variable size.
23 Input:
24 -nn_list = [input dim., first hidden layer size,...,
25 last hidden layer size, output dim.]
26
27 '''
28 def __init__(self, nn_list):
29 super(ReLU_NN, self).__init__()
30 self.nn_list = nn_list
31 self.hidden = tr.nn.ModuleList()
32 for i in range(len(nn_list)-1):
33 self.hidden.append(tr.nn.Linear(nn_list[i], nn_list[i+1]).double())
and now i am trying to add some save methods:
332 def save_state_dict(self, PATH):
333 '''
334 A common PyTorch convention is to save models using either a .pt or
335 .pth file extension.
336 '''
337 tr.save(self.state_dict(), PATH)
338
339
340 @classmethod
341 def load_state_dict(cls, PATH):
342 '''
343
344 '''
345 state_dict = tr.load(PATH)
346 # Attention: here is assumed that all layers have bias
347 nn_list = [None]*(len(state_dict)//2+1)
348 nn_list[0] = state_dict['hidden.0.weight'].shape[1]
349 for i in range(len(state_dict)//2):
350 nn_list[i+1] = state_dict[f'hidden.{i}.weight'].shape[0]
351
352 model = cls(nn_list)
353 print(model.nn_list)
354 model.load_state_dict(tr.load(PATH))
355
356 return model
I call the methods as follows:
76 nn_list = [P, neurons1, 1]
77 model = ReLU_NN(nn_list)
78 model.initialize_uniform(u=(-0.1,0.1), c=0.1)
79
80 PATH = '/home/lewin/01_studium/6_Masterarbeit/01_Programme/03_test_algorithms/testmodel.pt'
81 print(isinstance(PATH,str))
82 model.save_state_dict(PATH)
83 print(model.nn_list)
84 model2 = ReLU_NN.load_state_dict(PATH)
85 print(model2.state_dict())
This gives the output:
True
[4, 32, 1]
[4, 32, 1]
Traceback (most recent call last):
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 311, in _check_seekable
f.seek(f.tell())
AttributeError: 'collections.OrderedDict' object has no attribute 'seek'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "04_greedy_nn.py", line 84, in <module>
model2 = ReLU_NN.load_state_dict(PATH)
File "/home/lewin/01_studium/6_Masterarbeit/01_Programme/aux_files/aux_PyTorch.py", line 354, in load_state_dict
model.load_state_dict(tr.load(PATH))
File "/home/lewin/01_studium/6_Masterarbeit/01_Programme/aux_files/aux_PyTorch.py", line 345, in load_state_dict
state_dict = tr.load(PATH)
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 584, in load
with _open_file_like(f, 'rb') as opened_file:
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 239, in _open_file_like
return _open_buffer_reader(name_or_buffer)
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 224, in __init__
_check_seekable(buffer)
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 314, in _check_seekable
raise_err_msg(["seek", "tell"], e)
File "/home/lewin/.local/lib/python3.8/site-packages/torch/serialization.py", line 307, in raise_err_msg
raise type(e)(msg)
AttributeError: 'collections.OrderedDict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
The first True comes from line 81 and the nn_lists do coincide. I think the solution to this problem looks something like Trying to load a torch model via Dropbox, but i don’t get it.