How to load state_dict if saved as torch.save(model.state_dict(), fs)

I am testing torch.compile() with basic cifar10 and resnet, but cannot seem to figure out why I cannot load the saved state_dict. I am sure it is something fundamental that I am missing.
Might someone please be kind enough to point me in the right direction?

I am using: “python”: “3.12.3”, “torch”: “2.8.0+cu129”, “torchvision”: “0.23.0+cu129”

I saved the trained model with:

torch.save(mvl_model.state_dict(), fspec)

But when I try to load the saved model I get the following errors:

		with open(fspec_model_state, "rb") as fs:
			state_dict = torch.load(fs,
									map_location = self.device)
            model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
			model = torch.compile(model, mode="reduce-overhead")
			model.to(self.device)
			model.load_state_dict(state_dict)

E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] failed while attempting to run meta for aten.avg_pool2d.default
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] Traceback (most recent call last):
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py”, line 2717, in _dispatch_impl
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] r = func(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_ops.py”, line 829, in call
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] return self._op(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py”, line 2939, in meta_avg_pool2d
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] pool2d_shape_check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py”, line 4637, in pool2d_shape_check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] torch._check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/init.py”, line 1684, in _check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] _check_with(RuntimeError, cond, message)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File “/venv/py312/lib/python3.12/site-packages/torch/init.py”, line 1666, in _check_with
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] raise error_type(message_evaluated)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] RuntimeError: Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too small

from user code:
File “residual_net.py”, line 81, in forward
x = self.avgpool(x)

File “/venv/py312/lib/python3.12/site-packages/torch/nn/modules/pooling.py”, line 773, in forward
return F.avg_pool2d(

Training model creation:

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode=“reduce-overhead”)
model = model.to(device)

File “residual_net.py”:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
	super(ResidualBlock, self).__init__()
	self.conv1 = nn.Sequential(
					nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
					nn.BatchNorm2d(out_channels),
					nn.ReLU())
	self.conv2 = nn.Sequential(
					nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
					nn.BatchNorm2d(out_channels))
	self.downsample = downsample
	self.relu = nn.ReLU()
	self.out_channels = out_channels

def forward(self, x):
	residual = x
	out = self.conv1(x)
	out = self.conv2(out)
	if self.downsample:
		residual = self.downsample(x)
	out += residual
	out = self.relu(out)
	return out 

class ResNet(nn.Module):
	def __init__(self, block, layers, num_classes = 10):
		super(ResNet, self).__init__()
		self.inplanes = 64
		self.conv1 = nn.Sequential(
						nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
						nn.BatchNorm2d(64),
						nn.ReLU())
		self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
		self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
		self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
		self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
		self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
		self.avgpool = nn.AvgPool2d(7, stride=1)
		self.fc = nn.Linear(512, num_classes)

def _make_layer(self, block, planes, blocks, stride=1):
	downsample = None
	if stride != 1 or self.inplanes != planes:
		downsample = nn.Sequential(
			nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
			nn.BatchNorm2d(planes),
		)
	layers = []
	layers.append(block(self.inplanes, planes, stride, downsample))
	self.inplanes = planes
	for i in range(1, blocks):
		layers.append(block(self.inplanes, planes))
	return nn.Sequential(*layers)

def forward(self, x):
	x = self.conv1(x) 
	x = self.maxpool(x)
	x = self.layer0(x)
	x = self.layer1(x)
	x = self.layer2(x)
	x = self.layer3(x)

	x = self.avgpool(x)        # <-----LINE 81, in forward ****
	x = x.view(x.size(0), -1)
	x = self.fc(x)
	return x

Did you try to serialize the state_dict of the original model as described here? CC @marksaroufim in case this workaround is not needed anymore.

Hi @ptrblck,

Thank you for the quick response.

Apologies. I was able to load the state_dict without the Missing key(s) or Unexpected key(s) errors into the compiled model with no messages.

However, after loading state_dict into it, I was not able to use the compiled model for inference. Obviously I am still missing something. I got a little lost in the serialization discussion at the link you provided.

The error I am getting when I try to use the loaded model and state_dict in inference is:

with open(fspec_state_dict, "rb") as sd:
	state_dict = torch.load(sd,	map_location = device)
	model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
	model = torch.compile(model, mode="reduce-overhead")
	model.to(device)
	model.load_state_dict(state_dict)

with torch.inference_mode():
    for imgs, labels in loader:
	    imgs = imgs.to(self.device, non_blocking=True)
	    labels = labels.to(self.device, non_blocking=True)
	    outputs = model(imgs).to(self.device)     #  Error Line #
TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(…, device=‘cuda:0’, size=(128, 512, 1, 1)), 7, 1, 0, False, True, None), **{}): got RuntimeError(‘Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too small’)

from user code:
File “residual_net.py”, line 81, in forward
x = self.avgpool(x)
File “/home/surfer5/venv/py312/lib/python3.12/site-packages/torch/nn/modules/pooling.py”, line 773, in forward
return F.avg_pool2d(

Might anyone have any hints on how to get this to inference?

Thanks for clarifying your question. The error message indicates a pooling layer would create an output with a negative size as the input is too small. Did you change the input size when trying to run the model for inference?

Hi @ptrblck.

Not that I am aware of, unless the mode=”reduce-overhead” compile option does that?

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)

OR, perhaps if the following save command can cause that problem with a compiled model?

torch.save(model.state_dict(), fspec_state_dict)

Otherwise I am at a loss as to why compiling the model would cause that to occur when loading the state_dict into the compiled model.

I cannot reproduce any issues using:

device = "cuda"
model = ResNet(ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)
x = torch.randn(1, 3, 224, 224).to(device)

out = model(x)
print(out.shape)
# torch.Size([1, 10])


torch.save(model.state_dict(), "tmp.pt")

model = ResNet(ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)

sd = torch.load("tmp.pt")
model.load_state_dict(sd)
out = model(x)
print(out.shape)
# torch.Size([1, 10])

Do you see the error using my code?
If not, could you post a minimal and executable code snippet reproducing the issue?

@ptrblck - no, your code works fine. Unfortunately the problem appears in the code below. It’s probably something fundamental I am missing, like wondering why my untied shoes keep falling off. But I am not seeing how to get the compiled model on to the GPU for inference.

1. Not compiling the model causes the "_orig_mod.” error.

2. Moving the compiled model to GPU causes the errors;

from user code:
File "residual_net.py”, line 74, in forward
x = self.conv1(x)
File “/venv/py312/lib/python3.12/site-packages/torch/nn/modules/container.py”, line 244, in forward
input = module(input)
File “/venv/py312/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py”, line 193, in forward
return F.batch_norm(

3. Not moving the compiled model to GPU works, but runs inference on the CPU

import test_dataloader_CIFAR10
import residual_net as rnn
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda"

augment = False,
batch_size =	 64
data_dir = "/tmp"
learning_rate = 0.08
momentum = 0.8
normalize_image = True
num_workers = 4
pin_memory = True
persistent_workers = True
prefetch_factor = 2
shuffle = True
weight_decay = 0.002
valid_size = 0.1

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
loss_fn = nn.CrossEntropyLoss()

use_compiler = True
if use_compiler:
	model = torch.compile(model, mode="reduce-overhead")
	loss_fn = torch.compile(loss_fn, mode="reduce-overhead")

model = model.to(device)
loss_fn = loss_fn.to(device)
optimizer = optim.SGD(	model.parameters(),
						lr=learning_rate,
						weight_decay=weight_decay,
						momentum= momentum)

tdl = test_dataloader_CIFAR10.config_dataloaders()

num_epochs = 1
for epoch in range(num_epochs):

	train_loader, valid_loader = tdl.get_train_valid_loader(
							augment = augment,
							batch_size = batch_size,
							data_dir = data_dir,
							normalize_image = normalize_image,
							num_workers = num_workers,
							pin_memory = pin_memory,
							persistent_workers = persistent_workers,
							prefetch_factor = prefetch_factor,
							shuffle = shuffle,
							valid_size = valid_size)

	for images, labels in train_loader:

		# Move tensors to the configured device
		images = images.to(device, non_blocking=True)
		labels = labels.to(device, non_blocking=True)

		# Forward pass
		outputs = model(images).to(device)
		loss = loss_fn(outputs, labels).to(device)

		# Backward and optimize
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	del images, labels, outputs
	torch.save(model.state_dict(), f"epoch_{epoch}.pt")

print(f"\nTraining completed through epoch = {epoch}")

#*******************************************************
compiled_model_to_device = True
# True ->  ERROR
# False -> NO error, but runs inference on CPU.
#*******************************************************

print(f"\nStarting inference with compiled_model_to_device = {compiled_model_to_device}")

for epoch in range(num_epochs):
	
	fname = f"epoch_{epoch}.pt"

	model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
	model = torch.compile(model, mode="reduce-overhead")
	sd = torch.load(fname)
	model.load_state_dict(sd)

	#*** moving the compiled model to the GPU causes the error ***
	if compiled_model_to_device:
		model.to(device)

	print(f"\nLoaded state_dict: '{fname}'")

	loader =	 tdl.get_test_loader(
						augment = augment,
						batch_size = batch_size,
						data_dir = data_dir,
						normalize_image = normalize_image,
						num_workers = num_workers,
						pin_memory = pin_memory,
						shuffle = shuffle)	

	for images, labels in loader:
		if compiled_model_to_device:
			out = model(images).to(device)
		else:
			out = model(images)
		
	print(f"\nout.shape = {out.shape}")
	print(f"\nInferenced using state_dict: '{fname}'")

	del images, labels