No error raised despite shape mismatch in load_state_dict

When I initialize a Parameter of a Module to be a scalar and then load a state_dict containing a 1D tensor, no error is raised; instead, the Parameter silently takes on the first value of the 1D tensor. This surprised me, as I expected a RuntimeError. See the following code for an example; the assertion passes:

import torch
import torch.nn as nn


class SimpleModule(nn.Module):
    def __init__(self, threshold_value):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(threshold_value))
    
    def forward(self, x):
        return x

# 1D tensor whose shape (torch.Size([32000])) does not match the scalar Parameter
large_tensor = torch.randn(32000)
state_dict = {"threshold": large_tensor}

module = SimpleModule(0.0)
module.load_state_dict(state_dict)  # no error, despite the shape mismatch
assert module.threshold.item() == state_dict["threshold"][0].item()

Is this behavior intentional / documented?

Hi Michael!

I can reproduce your issue (on 2.9.0) using your example code.

This looks like a bug to me – you might consider filing a GitHub issue. (I don’t see this behavior mentioned or documented anywhere.)

Note that torch.tensor(0.0) creates a pure-scalar “no-dimensions” tensor of shape
torch.Size([]) (with value 0.0).
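You can check the shape difference directly:

```python
import torch

# torch.tensor(0.0) produces a 0-dimensional ("scalar") tensor
scalar = torch.tensor(0.0)
print(scalar.shape)  # torch.Size([])
print(scalar.dim())  # 0

# torch.tensor([0.0]) produces a 1-dimensional, one-element tensor
vector = torch.tensor([0.0])
print(vector.shape)  # torch.Size([1])
```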

As far as I can tell, this bug – of not flagging the shape mismatch – only occurs when trying
to load a “regular” tensor into a torch.Size([]) tensor. Loading a torch.Size([]) tensor
into a “regular” tensor fails as expected with a size-mismatch error.

(In particular, even loading a torch.Size([]) scalar into a torch.Size([1]) one-element tensor fails.)
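Here is a quick sketch of that reverse direction (the module and parameter names are just for illustration):

```python
import torch
import torch.nn as nn


class VectorModule(nn.Module):
    def __init__(self):
        super().__init__()
        # a "regular" 1D parameter of shape torch.Size([3])
        self.weight = nn.Parameter(torch.zeros(3))


module = VectorModule()
try:
    # try to load a torch.Size([]) scalar into the 1D parameter
    module.load_state_dict({"weight": torch.tensor(0.0)})
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # reports a size mismatch for "weight"
```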

Best.

K. Frank

Thanks for taking a look and confirming that it reproduces on 2.9.0! I’ve opened an issue here.
