I saw two methods of accessing a model’s parameters:
Which is more correct?
What are the differences?
Thanks
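The snippet from the original post was not preserved here; judging by the replies, the two methods in question are model.parameters() and model.state_dict(). A minimal sketch contrasting the two accessors (the nn.Linear model is just a stand-in):

```python
import torch
import torch.nn as nn

# Stand-in model; any nn.Module behaves the same way.
model = nn.Linear(4, 2)

# Method 1: parameters() yields the live nn.Parameter tensors,
# in the order an optimizer would consume them (no names).
for p in model.parameters():
    print(p.shape)

# Method 2: state_dict() returns an ordered dict mapping
# qualified names to tensors.
for name, t in model.state_dict().items():
    print(name, t.shape)
```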
model.parameters() contains all the learnable parameters of the model. state_dict() is a Python dict mapping each layer's name to its parameter tensors. Optimizers have their own state_dict as well. Since state_dicts are plain Python dicts, they are easy to save and load. docs
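Since both modules and optimizers expose state_dict(), a checkpoint is usually just a dict of dicts. A minimal sketch of the save/restore round trip (an in-memory buffer stands in for a real checkpoint file, and the nn.Linear model is just a stand-in):

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Both state_dicts are plain Python dicts, so they serialize together.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
buffer = io.BytesIO()  # stands in for a path like 'checkpoint.pt'
torch.save(checkpoint, buffer)

# Restoring copies the saved tensors back into the existing objects.
buffer.seek(0)
restored = torch.load(buffer)
model.load_state_dict(restored['model'])
optimizer.load_state_dict(restored['optimizer'])
```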
In addition to @Kushaj's description, the state_dict holds all buffers besides the learnable parameters, e.g. BatchNorm's running estimates, so that you can restore your model properly.
I think this should be added to the docs as well.
EDIT: It is in the docs, my bad.
Well, in the docs you've posted, this sentence might be misleading:
Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) have entries in the model’s state_dict.
So I see your point here.
Have a look at this simple model:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.register_buffer('buf1', torch.randn(1))

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        return x

model = MyModel()
print(model.state_dict()['buf1'])
As you can see, buf1 is also in the state_dict.
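As a flip side of the same example, buf1 never appears in parameters(); only the state_dict carries buffers and running estimates. A small check built on the MyModel snippet above:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.register_buffer('buf1', torch.randn(1))

    def forward(self, x):
        return self.bn1(self.conv1(x))

model = MyModel()
param_names = {name for name, _ in model.named_parameters()}
state_keys = set(model.state_dict().keys())

print('buf1' in param_names)   # False: buffers are not parameters
print('buf1' in state_keys)    # True: but they are in the state_dict
# The extra keys are exactly the buffers and BatchNorm running estimates.
print(sorted(state_keys - param_names))
```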
Could you create an issue on GitHub and propose a change?
If not, let me know and I’ll do it.
I will post the issue.
Another difference appears to be that for all learnable parameters which appear in both model.state_dict() and model.parameters(), requires_grad is True in the model.parameters() output and False in the model.state_dict() output.
Why is this?
That's true; I'm not sure exactly why, but after reloading my model using load_state_dict, as in
model.load_state_dict(torch.load(PATH))
it does everything as expected and sets all the parameters with requires_grad=True.
So my guess is that in the state_dict, requires_grad is False in order to avoid any unnecessary autograd graph creation if the user does some computation on the tensors taken directly from the state_dict.
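This guess matches the default behavior: state_dict() with keep_vars=False (the default) detaches each tensor, while parameters() hands back the live nn.Parameter objects. A small check (the nn.Linear model is just a stand-in):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# parameters() returns the live nn.Parameter objects.
print(all(p.requires_grad for p in model.parameters()))           # True

# state_dict() detaches its tensors by default, so no grad tracking.
print(any(t.requires_grad for t in model.state_dict().values()))  # False

# keep_vars=True skips the detach and returns the parameters themselves.
print(all(t.requires_grad
          for t in model.state_dict(keep_vars=True).values()))    # True
```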
I asked on Stack Overflow and got this answer.
If something should be added to or removed from it, please let me know.
Otherwise, I will accept it here as well.