State dict and loading of buffers

I’m training a simple vision model, and for some reason loading saved buffers does not work. Other weights seem to load just fine. Not sure what the reason could be. Any suggestions are welcome.

In [21]: pretrained_dict['layer1.0.relu1.running_max']
Out[21]: tensor(2.8699, device='cuda:1')

In [22]: net.layer1[0].relu1.running_max
Out[22]: tensor(2)

In [23]: net.load_state_dict(pretrained_dict)
Out[23]: <All keys matched successfully>

In [24]: net.layer1[0].relu1.running_max
Out[24]: tensor(2)

Could you post a minimal and executable code snippet reproducing the issue?

I have resolved my issue. For those who might face something similar, here is the solution and the reason for the problem.
Here’s roughly my code (it does something similar to batch norm, except that it tracks the batch max):

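import torch
import torch.nn as nn

class Module(nn.Module):
    def __init__(self, op_mode, momentum=0.1):
        super().__init__()
        self.momentum = momentum  # assumed default, like BatchNorm's 0.1
        self.register_buffer('running_max', torch.tensor(1))  # the problem is here: it should have been torch.tensor(1.0)

    def forward(self, x):
        with torch.no_grad():  # running statistics should not track gradients
            xmax = torch.max(torch.abs(x))
            self.running_max = (1 - self.momentum) * self.running_max + self.momentum * xmax
        return x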

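The loading doesn't work because the buffer is initially created as an integer (int64) tensor rather than float32. load_state_dict copies the saved values into the existing buffer in place, so the float value is silently truncated to an integer instead of the buffer being promoted to float32.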
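A minimal sketch reproducing the behavior, assuming a recent PyTorch release. The MaxTracker module is hypothetical and only exists to hold a single buffer; the saved value 2.8699 is taken from the session above.

```python
import torch
import torch.nn as nn


class MaxTracker(nn.Module):
    """Hypothetical minimal module holding one buffer, for illustration only."""

    def __init__(self, init):
        super().__init__()
        self.register_buffer("running_max", init)


# An int literal gives an int64 buffer, a float literal gives a float32 buffer
int_buf = MaxTracker(torch.tensor(1))
float_buf = MaxTracker(torch.tensor(1.0))

saved = {"running_max": torch.tensor(2.8699)}

int_buf.load_state_dict(saved)    # copied in place: stays int64, value truncated
float_buf.load_state_dict(saved)  # copied in place into a float32 buffer

print(int_buf.running_max)    # tensor(2)
print(float_buf.running_max)  # tensor(2.8699)
```

Both calls report that all keys matched, which is why the problem is easy to miss: load_state_dict checks keys and shapes, but the dtype of the target buffer silently wins.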


@ptrblck, do you think this deserves an issue on the PyTorch GitHub issue tracker?

I would assume this is expected behavior since you are explicitly defining running_max as an integer type via:

torch.tensor(1).type()
# 'torch.LongTensor'

Copying data into the LongTensor will convert the data to the target's dtype instead of changing the dtype of the target tensor:

x = torch.tensor(1)          # int64 target
x.copy_(torch.tensor(2.3))   # copy_ expects a tensor source; the float is cast to int64
print(x)
# tensor(2)

which I would also expect.

Shouldn’t there at least be some sort of warning when loading weights with a different dtype, instead of silently casting them? This really threw me off:

In [23]: net.load_state_dict(pretrained_dict)
Out[23]: <All keys matched successfully>

Shouldn’t it output something like "All keys matched, some keys had their types converted", or something along those lines? I guess spotting the difference between 1 and 1.0 is something you get used to, but man, this took a while to diagnose.

It might be a good idea to add a warning, and you could propose this as a feature request on GitHub.
If you want the loaded tensors to keep the properties stored in the state_dict (such as its dtype) instead of the module's, you could use m.load_state_dict(sd, assign=True).

Thanks a lot! I raised an issue: Inform when weights loading promotes types · Issue #126000 · pytorch/pytorch
Could you suggest a way to implement that? If this feature is deemed desirable, I’d like to try sending a PR, if it is not too difficult.
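Until something like this exists in PyTorch itself, one possible user-side workaround is to compare dtypes against module.state_dict() before loading. This is only a sketch, not how PyTorch would implement it internally; load_state_dict_with_dtype_check is a hypothetical helper name.

```python
import warnings

import torch
import torch.nn as nn


def load_state_dict_with_dtype_check(module, state_dict, **kwargs):
    """Hypothetical helper: warn about state_dict entries whose dtype differs
    from the module's current tensor, then call module.load_state_dict."""
    current = module.state_dict()
    mismatched = []
    for key, value in state_dict.items():
        target = current.get(key)
        # Only compare tensor entries that exist on both sides
        if torch.is_tensor(value) and torch.is_tensor(target) and value.dtype != target.dtype:
            mismatched.append((key, value.dtype, target.dtype))
            warnings.warn(
                f"{key}: checkpoint dtype {value.dtype} differs from module dtype {target.dtype}"
            )
    result = module.load_state_dict(state_dict, **kwargs)
    return result, mismatched


# Usage: an int64 buffer receiving a float32 value triggers the warning
owner = nn.Module()
owner.register_buffer("running_max", torch.tensor(1))
result, mismatched = load_state_dict_with_dtype_check(
    owner, {"running_max": torch.tensor(2.8699)}
)
print(mismatched)  # [('running_max', torch.float32, torch.int64)]
```

The check runs before loading, so it reports the mismatch regardless of whether assign=True is passed through kwargs; with the default copy the value is still truncated, as the warning suggests.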