Greetings,
I recently tried to save/load several parts of my model (each defined as an nn.Sequential), like this:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.sequential_1 = nn.Sequential( ... )
        ...
        self.sequential_N = nn.Sequential( ... )

    def forward(self, x):
        x1 = self.sequential_1(x)
        ...
        z = self.sequential_N(xn)
        return z
model = MyModel()
and then saved each part like this:
torch.save(model.sequential_1.state_dict(), './data/sequential_1.pth')
...
torch.save(model.sequential_N.state_dict(), './data/sequential_N.pth')
I'm not fully sure the per-part save process above is valid, so I additionally saved the whole model's state_dict as a backup:
torch.save(model.state_dict(), './data/model.pth')
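For completeness, an alternative layout I also considered is bundling every part's state_dict into a single checkpoint file instead of N separate ones (just a sketch; the file name 'parts.pth' and the dict keys below are my own labels, not anything PyTorch requires):

# One checkpoint file holding all parts, keyed by names I chose myself.
checkpoint = {
    'sequential_1': model.sequential_1.state_dict(),
    # ...
    'sequential_N': model.sequential_N.state_dict(),
}
torch.save(checkpoint, './data/parts.pth')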
Once I had those .pth files, I tried to restore the previous state, like this:
model_2 = MyModel() ## for checking
model_2.sequential_1.load_state_dict(torch.load('./data/sequential_1.pth'))
...
model_2.sequential_N.load_state_dict(torch.load('./data/sequential_N.pth'))
but this model_2 doesn't work as I expected.
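To make "doesn't work" concrete, a quick way to check whether the weights actually transferred is to compare the two state_dicts tensor by tensor (a debugging sketch; same_weights is my own helper, not a PyTorch API):

def same_weights(m1, m2):
    # My own helper: True only if every tensor in the two state_dicts matches.
    sd1, sd2 = m1.state_dict(), m2.state_dict()
    return all(torch.equal(sd1[k], sd2[k]) for k in sd1)

print(same_weights(model.sequential_1, model_2.sequential_1))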
So I had no choice but to reload the whole model in the conventional way:
model.load_state_dict(torch.load('./data/model.pth'))
This model works as I expected. However, the weird part starts here.
I ran model_2 again to double-check its behavior, and this time it works PROPERLY.
My questions are the following:
- Is this an expected behavior?
- What is the proper way to save/load part of a model with torch.save/torch.load?
I am also trying to reproduce this with a small toy example…
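For reference, this is the kind of minimal setup I am experimenting with (the layer sizes and the two-part split are arbitrary choices of mine):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.sequential_2 = nn.Sequential(nn.Linear(8, 2))

    def forward(self, x):
        return self.sequential_2(self.sequential_1(x))

model = ToyModel()
torch.save(model.sequential_1.state_dict(), './sequential_1.pth')
torch.save(model.sequential_2.state_dict(), './sequential_2.pth')

model_2 = ToyModel()  # freshly initialized copy
model_2.sequential_1.load_state_dict(torch.load('./sequential_1.pth'))
model_2.sequential_2.load_state_dict(torch.load('./sequential_2.pth'))

# With identical weights, both models should produce identical outputs.
x = torch.randn(1, 4)
print(torch.equal(model(x), model_2(x)))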