I am testing torch.compile() with basic cifar10 and resnet, but cannot seem to figure out why I cannot load the saved state_dict. I am sure it is something fundamental that I am missing.
Might someone please be kind enough to point me in the right direction?
I am using: “python”: “3.12.3”, “torch”: “2.8.0+cu129”, “torchvision”: “0.23.0+cu129”
I saved the trained model with:
torch.save(mvl_model.state_dict(), fspec)
But when I try to load the saved model I get the following errors:
with open(fspec_model_state, "rb") as fs: state_dict = torch.load(fs, map_location = self.device) model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3]) model = torch.compile(model, mode="reduce-overhead") model.to(self.device) model.load_state_dict(state_dict)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] failed while attempting to run meta for aten.avg_pool2d.default
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] Traceback (most recent call last):
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py”, line 2717, in _dispatch_impl
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] r = func(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_ops.py”, line 829, in call
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] return self._op(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py”, line 2939, in meta_avg_pool2d
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] pool2d_shape_check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py”, line 4637, in pool2d_shape_check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] torch._check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/init.py”, line 1684, in _check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] _check_with(RuntimeError, cond, message)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/init.py”, line 1666, in _check_with
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] raise error_type(message_evaluated)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] RuntimeError: Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too smallfrom user code:
File “residual_net.py”, line 81, in forward
x = self.avgpool(x)File “/venv/py312/lib/python3.12/site-packages/torch/nn/modules/pooling.py”, line 773, in forward
return F.avg_pool2d(
Training model creation:
model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode=“reduce-overhead”)
model = model.to(device)
File “residual_net.py”:
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
nn.BatchNorm2d(out_channels),
nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels))
self.downsample = downsample
self.relu = nn.ReLU()
self.out_channels = out_channels
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes = 10):
super(ResNet, self).__init__()
self.inplanes = 64
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
nn.BatchNorm2d(64),
nn.ReLU())
self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
nn.BatchNorm2d(planes),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x) # <-----LINE 81, in forward ****
x = x.view(x.size(0), -1)
x = self.fc(x)
return x