I recently tried to save/load several parts of my model (each defined as an nn.Sequential()), like this:
```python
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.sequential_1 = nn.Sequential( ... )
        ...
        self.sequential_N = nn.Sequential( ... )

    def forward(self, x):
        x1 = self.sequential_1(x)
        ...
        z = self.sequential_N(xn)
        return z

model = MyModel()
```
and saved the parts like this:
```python
torch.save(model.sequential_1.state_dict(), './data/sequential_1.pth')
...
torch.save(model.sequential_N.state_dict(), './data/sequential_N.pth')
```
I wasn't fully sure the save process above was valid, so I additionally saved the whole model as a backup.
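For reference, the backup was just the full state_dict, roughly like this (the filename 'model.pth' is only an example name for this sketch):

```python
# Back up the whole model's parameters in one file
# (path name is just an example)
torch.save(model.state_dict(), './data/model.pth')
```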
Once I had those .pth files, I tried to restore the previous state, like this:
```python
model_2 = MyModel()  # for checking
model_2.sequential_1.load_state_dict(torch.load('./data/sequential_1.pth'))
...
model_2.sequential_N.load_state_dict(torch.load('./data/sequential_N.pth'))
```
However, model_2 didn't work as I expected.
So, I had no choice but to reload the whole model in the conventional way, like this:
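(Roughly the following, reusing the backup from above; './data/model.pth' is just my example path:)

```python
model = MyModel()
# Restore the whole model from the backup state_dict
# (assumes it was saved to './data/model.pth' in the backup step above)
model.load_state_dict(torch.load('./data/model.pth'))
```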
model worked as I expected. However, the weird problem starts here: I then ran model_2 again because I wanted to double-check its behavior, and this time it worked PROPERLY.
My questions are the following:
- Is this an expected behavior?
- What is the proper way to save/load part of a model with torch.save/torch.load?
I am also trying to reproduce this with a small toy example…
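In the meantime, here is roughly the setup I'm using for the repro (the layer sizes and the allclose check are made up just for this sketch):

```python
import os
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.sequential_2 = nn.Sequential(nn.Linear(16, 4))

    def forward(self, x):
        x1 = self.sequential_1(x)
        return self.sequential_2(x1)

os.makedirs('./data', exist_ok=True)
model = ToyModel()

# Save each part separately
torch.save(model.sequential_1.state_dict(), './data/sequential_1.pth')
torch.save(model.sequential_2.state_dict(), './data/sequential_2.pth')

# Reload the parts into a fresh instance
model_2 = ToyModel()
model_2.sequential_1.load_state_dict(torch.load('./data/sequential_1.pth'))
model_2.sequential_2.load_state_dict(torch.load('./data/sequential_2.pth'))

# The two models should now produce identical outputs on the same input
x = torch.randn(2, 8)
with torch.no_grad():
    print(torch.allclose(model(x), model_2(x)))  # expect: True
```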