Does torch.load() affect another model?

Greetings,
I recently tried to save and load several parts of a model (each defined as an nn.Sequential()) separately, like this:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential_1 = nn.Sequential( ... )
        ...
        self.sequential_N = nn.Sequential( ... )

    def forward(self, x):
        x1 = self.sequential_1(x)
        ...
        z = self.sequential_N(xn)
        return z

model = MyModel()

and saved the parts like this:

torch.save(model.sequential_1.state_dict(), './data/sequential_1.pth')
...
torch.save(model.sequential_N.state_dict(), './data/sequential_N.pth')

I wasn't fully sure the save process above is valid, so I also saved the whole model as a backup:

torch.save(model.state_dict(), './data/model.pth')
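Saving each part separately is valid in principle: a sub-module's state_dict() holds the same tensors as the full model's state_dict(), only without the attribute-name prefix. A minimal sketch to check this, assuming a hypothetical two-part TinyModel (layer sizes chosen arbitrarily) as a stand-in for MyModel:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for MyModel with two nn.Sequential parts."""

    def __init__(self):
        super().__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.sequential_2 = nn.Sequential(nn.Linear(8, 2))

    def forward(self, x):
        return self.sequential_2(self.sequential_1(x))


model = TinyModel()

# Full-model keys carry the attribute name as a prefix, e.g. "sequential_1.0.weight".
full_keys = set(model.state_dict().keys())

# Sub-module keys are relative to the sub-module, e.g. "0.weight";
# re-adding the attribute name reconstructs the full-model keys.
part_keys = {
    f"{name}.{key}"
    for name in ("sequential_1", "sequential_2")
    for key in getattr(model, name).state_dict().keys()
}

print(sorted(full_keys))
print(full_keys == part_keys)  # True: the parts together cover the whole model
```

If the comparison is False for a real model, some parameters live outside the parts being saved.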

Once I had those .pth files, I tried to restore the previous state like this:

model_2 = MyModel() ## for checking
model_2.sequential_1.load_state_dict(torch.load('./data/sequential_1.pth'))
...
model_2.sequential_N.load_state_dict(torch.load('./data/sequential_N.pth'))
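Restoring the parts into a fresh instance can be verified by comparing outputs. A sketch under the same hypothetical TinyModel assumption, using in-memory io.BytesIO buffers in place of the './data/*.pth' files (torch.save and torch.load accept file-like objects as well as paths):

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical two-part stand-in for MyModel."""

    def __init__(self):
        super().__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.sequential_2 = nn.Sequential(nn.Linear(8, 2))

    def forward(self, x):
        return self.sequential_2(self.sequential_1(x))


torch.manual_seed(0)
model = TinyModel()

# Save each part separately (in-memory buffers instead of files).
buffers = {}
for name in ("sequential_1", "sequential_2"):
    buf = io.BytesIO()
    torch.save(getattr(model, name).state_dict(), buf)
    buffers[name] = buf

# Restore every part into a fresh, differently initialized instance.
model_2 = TinyModel()
for name, buf in buffers.items():
    buf.seek(0)  # rewind before reading
    getattr(model_2, name).load_state_dict(torch.load(buf))

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), model_2(x))
print(same)  # True when every parameter lives in one of the saved parts
```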

but model_2 did not behave as I expected.
So I had no choice but to reload the whole model the conventional way:

model.load_state_dict(torch.load('./data/model.pth'))

This model works as expected. However, this is where the weird part starts.
I ran model_2 again to double-check its behavior, and now it works PROPERLY, which confuses me.

My questions are:

  1. Is this expected behavior?
  2. What is the proper way to save/load part of a model with torch.save/torch.load?

I am also trying to reproduce this with a small toy example.

It's probably related to the optimizer's state dict rather than the model's, unless you forgot to save some parameter that is not contained in the sequential_1 … sequential_N modules.
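The second possibility can be checked mechanically: list every key of the full state_dict() that none of the separately saved parts would cover. A sketch, assuming a hypothetical model with an extra part sequential_x that is deliberately left out of the list of saved parts (uncovered_keys is an illustrative helper, not a PyTorch API):

```python
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical model: two saved parts plus one part that is easy to forget."""

    def __init__(self):
        super().__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.sequential_x = nn.Sequential(nn.Linear(8, 8))
        self.sequential_2 = nn.Sequential(nn.Linear(8, 2))

    def forward(self, x):
        return self.sequential_2(self.sequential_x(self.sequential_1(x)))


def uncovered_keys(model, saved_parts):
    """Return state_dict keys that none of the named sub-modules would save."""
    prefixes = tuple(f"{name}." for name in saved_parts)
    return [key for key in model.state_dict() if not key.startswith(prefixes)]


model = TinyModel()
missing = uncovered_keys(model, ["sequential_1", "sequential_2"])
print(missing)  # ['sequential_x.0.weight', 'sequential_x.0.bias']
```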

Note that some optimizers (Adam, for instance) keep running statistics that are lost if you don't save the optimizer's state. If you try to resume training without those statistics, you will see a spike in the loss plot.
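For example, torch.optim.Adam stores per-parameter running averages in its own state, separate from model.state_dict(), so a checkpoint meant for resuming training usually bundles both. A minimal sketch; the checkpoint keys "model" and "optimizer" are arbitrary names chosen here, not a PyTorch requirement:

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so that Adam populates its running statistics.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Adam keeps these per parameter; they are NOT part of model.state_dict().
state_keys = set(optimizer.state_dict()["state"][0].keys())
print(sorted(state_keys))  # ['exp_avg', 'exp_avg_sq', 'step']

# Save both so that training can resume without resetting the statistics.
buf = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, buf)

# Later: rebuild the objects, then restore both states.
buf.seek(0)
checkpoint = torch.load(buf)
model_2 = nn.Linear(4, 2)
optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=1e-3)
model_2.load_state_dict(checkpoint["model"])
optimizer_2.load_state_dict(checkpoint["optimizer"])
```

For evaluation only, as in the reply below, the optimizer state is not needed at all.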

Hello Juan,

I only saved the model's state (torch.save(model.state_dict(), ...)) in the first session, and there is no optimizer in the second (loading) session, since I only evaluate the model there.

OK, I figured out what was happening. Basically, most of the problems came from my own misuse.

In my model definition, some parts were modules that had already been instantiated outside the model, like:

class Sequential_X(nn.Module):
    ...

sequential_x = Sequential_X()  ## instantiated once, in the global scope
sequential_x.requires_grad = False  ## note: only sets a plain attribute; use requires_grad_(False) to freeze the parameters
...

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...
        self.sequential_x = sequential_x  ## <- the source of the problem: every MyModel shares this same object :(
    ...
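The pitfall can be reproduced in isolation: because the module is created once at global scope, every model instance stores a reference to the same object, so loading weights into one model silently changes the other. A sketch with a hypothetical SharedPartModel (sizes chosen arbitrarily):

```python
import torch
import torch.nn as nn

# Created once at module (global) scope, as in the code above.
sequential_x = nn.Sequential(nn.Linear(4, 4))


class SharedPartModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 4))  # new object per instance
        self.sequential_x = sequential_x  # the SAME object for every instance

    def forward(self, x):
        return self.sequential_x(self.sequential_1(x))


model = SharedPartModel()
model_2 = SharedPartModel()
print(model.sequential_x is model_2.sequential_x)  # True
print(model.sequential_1 is model_2.sequential_1)  # False

# load_state_dict copies values in place, including into the shared sequential_x ...
new_state = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
model.load_state_dict(new_state)

# ... so model_2 sees the new sequential_x weights (but not the new sequential_1 ones).
print(torch.count_nonzero(model_2.sequential_x[0].weight).item())  # 0
```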

I thought sequential_x was frozen by sequential_x.requires_grad = False, but it was trained anyway (maybe I forgot to set it before the training session and changed the code later?). Since I only saved/loaded the modules that I thought had been trained in that session, the trained weights of sequential_x were never restored, which is why the first result was not what I expected.
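One likely reason sequential_x was trained anyway: nn.Module does not define a requires_grad attribute, so sequential_x.requires_grad = False merely attaches a plain Python attribute and leaves every parameter trainable. Module.requires_grad_(False) (or setting requires_grad on each parameter) actually freezes them. A small check, using an arbitrary single-layer module:

```python
import torch.nn as nn

part = nn.Sequential(nn.Linear(4, 4))

# Assigning an attribute on the module does NOT touch its parameters.
part.requires_grad = False
print(all(p.requires_grad for p in part.parameters()))  # True: still trainable

# requires_grad_() updates every parameter of the module, recursively.
part.requires_grad_(False)
print(any(p.requires_grad for p in part.parameters()))  # False: now frozen
```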

Then, when I loaded the whole model with model.load_state_dict(torch.load('./data/model.pth')), it also updated the globally defined sequential_x, which model_2 shares, so both model and model_2 worked fine.
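One way to avoid the hidden sharing is to let each model instance own its copy of the part and freeze it inside __init__. A sketch with a hypothetical FrozenPartModel; when the part has to come from outside (e.g. pretrained weights), copy.deepcopy gives each instance an independent copy instead of a shared reference:

```python
import copy

import torch
import torch.nn as nn


class FrozenPartModel(nn.Module):
    def __init__(self, pretrained_part=None):
        super().__init__()
        self.sequential_1 = nn.Sequential(nn.Linear(4, 4))
        if pretrained_part is None:
            self.sequential_x = nn.Sequential(nn.Linear(4, 4))
        else:
            # Copy so that this instance owns its weights instead of sharing them.
            self.sequential_x = copy.deepcopy(pretrained_part)
        self.sequential_x.requires_grad_(False)  # really freezes the parameters

    def forward(self, x):
        return self.sequential_x(self.sequential_1(x))


pretrained = nn.Sequential(nn.Linear(4, 4))
model = FrozenPartModel(pretrained)
model_2 = FrozenPartModel(pretrained)
print(model.sequential_x is model_2.sequential_x)  # False

# Loading into `model` no longer changes model_2 behind the scenes.
model.load_state_dict({k: torch.zeros_like(v) for k, v in model.state_dict().items()})
print(torch.count_nonzero(model_2.sequential_x[0].weight).item() > 0)  # True
```

With this layout, saving and loading model.state_dict() (or each part's state_dict()) restores exactly the weights of that one instance.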