When I initialize a Parameter of a Module to be a scalar, and then load in a state_dict containing a 1D tensor, there is no error; instead, the Parameter takes on the first value of the 1D tensor without complaint. This behavior is a bit surprising to me, as I expected a RuntimeError. See the following code for an example; the assertion passes:
import torch
import torch.nn as nn


class SimpleModule(nn.Module):
    def __init__(self, threshold_value):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(threshold_value))

    def forward(self, x):
        return x


large_tensor = torch.randn(32000)
state_dict = {"threshold": large_tensor}

module = SimpleModule(0.0)
module.load_state_dict(state_dict)

# the scalar parameter silently took on the first element of the 1D tensor
assert module.threshold.item() == state_dict["threshold"][0].item()
I can reproduce your issue (on 2.9.0) using your example code.
This looks like a bug to me – you might consider filing a GitHub issue. (I don’t see this
behavior mentioned or documented anywhere.)
Note that torch.tensor(0.0) creates a pure-scalar “no-dimensions” tensor of shape torch.Size([]) (with value 0.0).
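For reference, here is a quick check (my own illustration, not from your post) of the shape difference between a no-dimensions scalar and a one-element 1D tensor:

    import torch

    scalar = torch.tensor(0.0)    # zero-dimensional: shape torch.Size([])
    vector = torch.tensor([0.0])  # one-dimensional: shape torch.Size([1])

    print(scalar.shape, scalar.dim())  # torch.Size([]) 0
    print(vector.shape, vector.dim())  # torch.Size([1]) 1

Both hold a single value, but they are distinct shapes as far as state_dict loading is concerned.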
As far as I can tell, this bug – of not flagging the shape mismatch – only occurs when trying
to load a “regular” tensor into a torch.Size([]) tensor. Loading a torch.Size([]) tensor
into a “regular” tensor fails as expected with a size-mismatch error.
(In particular, loading a torch.Size([]) scalar into a torch.Size([1]) one-element tensor also fails.)
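A minimal sketch of that reverse direction (the module and parameter names here are hypothetical, chosen to mirror your example; behavior as observed on recent torch versions):

    import torch
    import torch.nn as nn

    class OneElementModule(nn.Module):
        def __init__(self):
            super().__init__()
            # a "regular" one-element parameter: shape torch.Size([1])
            self.threshold = nn.Parameter(torch.tensor([0.0]))

    module = OneElementModule()
    try:
        # attempt to load a torch.Size([]) scalar into the torch.Size([1]) parameter
        module.load_state_dict({"threshold": torch.tensor(0.0)})
    except RuntimeError as err:
        # strict loading flags the size mismatch, as expected
        print(type(err).__name__)

So the shape check does fire in this direction; only the regular-tensor-into-torch.Size([]) case slips through silently.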