How to load state_dict if saved as torch.save(model.state_dict(), fs)

I am testing torch.compile() with a basic CIFAR-10 ResNet, but cannot figure out why I cannot load the saved state_dict. I am sure it is something fundamental that I am missing.
Might someone please be kind enough to point me in the right direction?

I am using: "python": "3.12.3", "torch": "2.8.0+cu129", "torchvision": "0.23.0+cu129"

I saved the trained model with:

torch.save(mvl_model.state_dict(), fspec)

But when I try to load the saved model I get the following errors:

with open(fspec_model_state, "rb") as fs:
    state_dict = torch.load(fs, map_location=self.device)
    model = rnn.ResNet(rnn.ResidualBlock, [3, 4, 6, 3])
    model = torch.compile(model, mode="reduce-overhead")
    model.to(self.device)
    model.load_state_dict(state_dict)

E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] failed while attempting to run meta for aten.avg_pool2d.default
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] Traceback (most recent call last):
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] r = func(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] return self._op(*args, **kwargs)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py", line 2939, in meta_avg_pool2d
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] pool2d_shape_check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/_meta_registrations.py", line 4637, in pool2d_shape_check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] torch._check(
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/__init__.py", line 1684, in _check
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] _check_with(RuntimeError, cond, message)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] File "/venv/py312/lib/python3.12/site-packages/torch/__init__.py", line 1666, in _check_with
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] raise error_type(message_evaluated)
E1128 10:37:06.355000 72103 torch/_subclasses/fake_tensor.py:2721] [0/0] RuntimeError: Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too small

from user code:
File "residual_net.py", line 81, in forward
x = self.avgpool(x)

File "/venv/py312/lib/python3.12/site-packages/torch/nn/modules/pooling.py", line 773, in forward
return F.avg_pool2d(

Training model creation:

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)

File “residual_net.py”:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)        # <----- LINE 81, in forward ****
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

Did you try to serialize the state_dict of the original model as described here? CC @marksaroufim in case this workaround is not needed anymore.
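That workaround boils down to the fact that torch.compile wraps the model and prefixes every state_dict key with `_orig_mod.`, so a checkpoint saved from a compiled model will not match a plain model's keys. A minimal sketch of stripping that prefix at load time (plain dicts stand in for real state_dicts here):

```python
# Keys saved from a torch.compile-wrapped model carry an "_orig_mod." prefix.
# Stripping it lets the checkpoint load into an uncompiled model.
PREFIX = "_orig_mod."

def strip_compile_prefix(state_dict):
    # removeprefix leaves keys without the prefix untouched (Python 3.9+)
    return {k.removeprefix(PREFIX): v for k, v in state_dict.items()}

# Plain dict standing in for a real state_dict:
compiled_sd = {"_orig_mod.conv1.0.weight": "w", "_orig_mod.fc.bias": "b"}
print(strip_compile_prefix(compiled_sd))
# {'conv1.0.weight': 'w', 'fc.bias': 'b'}
```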

Hi @ptrblck,

Thank you for the quick response.

Apologies. I was able to load the state_dict into the compiled model without the Missing key(s) or Unexpected key(s) errors, and with no messages at all.

However, after loading the state_dict, I was not able to use the compiled model for inference. Obviously I am still missing something. I got a little lost in the serialization discussion at the link you provided.

The error I am getting when I try to use the loaded model and state_dict in inference is:

with open(fspec_state_dict, "rb") as sd:
    state_dict = torch.load(sd, map_location=device)
    model = rnn.ResNet(rnn.ResidualBlock, [3, 4, 6, 3])
    model = torch.compile(model, mode="reduce-overhead")
    model.to(device)
    model.load_state_dict(state_dict)

with torch.inference_mode():
    for imgs, labels in loader:
        imgs = imgs.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        outputs = model(imgs).to(self.device)     # Error Line #
TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(..., device='cuda:0', size=(128, 512, 1, 1)), 7, 1, 0, False, True, None), **{}): got RuntimeError('Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too small')

from user code:
File "residual_net.py", line 81, in forward
x = self.avgpool(x)
File "/home/surfer5/venv/py312/lib/python3.12/site-packages/torch/nn/modules/pooling.py", line 773, in forward
return F.avg_pool2d(

Might anyone have hints on how to get this model to run inference?

Thanks for clarifying your question. The error message indicates a pooling layer would create an output with a negative size as the input is too small. Did you change the input size when trying to run the model for inference?
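For example, tracing the spatial size through the posted architecture with the standard convolution/pooling output-size formula (a rough sketch; the 224 and 32 input sizes are only illustrative) shows how a too-small input drives the final AvgPool2d(7) to a negative output size:

```python
# Spatial-size trace for the posted ResNet (sketch; 224 and 32 are
# illustrative input sizes, not taken from the actual data loader).
def out_size(n, kernel, stride, padding):
    # standard conv/pool output-size formula
    return (n + 2 * padding - kernel) // stride + 1

def trace(n):
    n = out_size(n, 7, 2, 3)   # conv1: 7x7, stride 2, pad 3
    n = out_size(n, 3, 2, 1)   # maxpool: 3x3, stride 2, pad 1
    n = out_size(n, 3, 2, 1)   # layer1 first block: 3x3, stride 2 (layer0 keeps size)
    n = out_size(n, 3, 2, 1)   # layer2 first block: 3x3, stride 2
    n = out_size(n, 3, 2, 1)   # layer3 first block: 3x3, stride 2
    return n

print(trace(224))  # 7 -> AvgPool2d(7) yields 1x1, as expected
print(trace(32))   # 1 -> AvgPool2d(7) would compute (1 - 7) + 1 = -5
```

The 32-pixel case reproduces the `(512x1x1)` input and `(512x-5x-5)` output in your error message, which is why the input size is the first thing to check.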

Hi @ptrblck.

Not that I am aware of, unless the mode="reduce-overhead" compile option does that?

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)

OR, perhaps if the following save command can cause that problem with a compiled model?

torch.save(model.state_dict(), fspec_state_dict)

Otherwise I am at a loss as to why compiling the model would cause that to occur when loading the state_dict into the compiled model.

I cannot reproduce any issues using:

device = "cuda"
model = ResNet(ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)
x = torch.randn(1, 3, 224, 224).to(device)

out = model(x)
print(out.shape)
# torch.Size([1, 10])


torch.save(model.state_dict(), "tmp.pt")

model = ResNet(ResidualBlock, [3,4,6,3])
model = torch.compile(model, mode="reduce-overhead")
model = model.to(device)

sd = torch.load("tmp.pt")
model.load_state_dict(sd)
out = model(x)
print(out.shape)
# torch.Size([1, 10])

Do you see the error using my code?
If not, could you post a minimal and executable code snippet reproducing the issue?

@ptrblck - no, your code works fine. Unfortunately the problem appears in the code below. It's probably something fundamental I am missing, like wondering why my untied shoes keep falling off. But I am not seeing how to get the compiled model onto the GPU for inference.

1. Not compiling the model causes the "_orig_mod." error.

2. Moving the compiled model to the GPU causes the following errors:

from user code:
File "residual_net.py", line 74, in forward
x = self.conv1(x)
File "/venv/py312/lib/python3.12/site-packages/torch/nn/modules/container.py", line 244, in forward
input = module(input)
File "/venv/py312/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 193, in forward
return F.batch_norm(

3. Not moving the compiled model to the GPU works, but runs inference on the CPU.

import test_dataloader_CIFAR10
import residual_net as rnn
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda"

augment = False
batch_size = 64
data_dir = "/tmp"
learning_rate = 0.08
momentum = 0.8
normalize_image = True
num_workers = 4
pin_memory = True
persistent_workers = True
prefetch_factor = 2
shuffle = True
weight_decay = 0.002
valid_size = 0.1

model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
loss_fn = nn.CrossEntropyLoss()

use_compiler = True
if use_compiler:
	model = torch.compile(model, mode="reduce-overhead")
	loss_fn = torch.compile(loss_fn, mode="reduce-overhead")

model = model.to(device)
loss_fn = loss_fn.to(device)
optimizer = optim.SGD(	model.parameters(),
						lr=learning_rate,
						weight_decay=weight_decay,
						momentum= momentum)

tdl = test_dataloader_CIFAR10.config_dataloaders()

num_epochs = 1
for epoch in range(num_epochs):

	train_loader, valid_loader = tdl.get_train_valid_loader(
							augment = augment,
							batch_size = batch_size,
							data_dir = data_dir,
							normalize_image = normalize_image,
							num_workers = num_workers,
							pin_memory = pin_memory,
							persistent_workers = persistent_workers,
							prefetch_factor = prefetch_factor,
							shuffle = shuffle,
							valid_size = valid_size)

	for images, labels in train_loader:

		# Move tensors to the configured device
		images = images.to(device, non_blocking=True)
		labels = labels.to(device, non_blocking=True)

		# Forward pass
		outputs = model(images).to(device)
		loss = loss_fn(outputs, labels).to(device)

		# Backward and optimize
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	del images, labels, outputs
	torch.save(model.state_dict(), f"epoch_{epoch}.pt")

print(f"\nTraining completed through epoch = {epoch}")

#*******************************************************
compiled_model_to_device = True
# True ->  ERROR
# False -> NO error, but runs inference on CPU.
#*******************************************************

print(f"\nStarting inference with compiled_model_to_device = {compiled_model_to_device}")

for epoch in range(num_epochs):
	
	fname = f"epoch_{epoch}.pt"

	model = rnn.ResNet(rnn.ResidualBlock, [3,4,6,3])
	model = torch.compile(model, mode="reduce-overhead")
	sd = torch.load(fname)
	model.load_state_dict(sd)

	#*** moving the compiled model to the GPU causes the error ***
	if compiled_model_to_device:
		model.to(device)

	print(f"\nLoaded state_dict: '{fname}'")

loader = tdl.get_test_loader(
						augment = augment,
						batch_size = batch_size,
						data_dir = data_dir,
						normalize_image = normalize_image,
						num_workers = num_workers,
						pin_memory = pin_memory,
						shuffle = shuffle)	

	for images, labels in loader:
		if compiled_model_to_device:
			out = model(images).to(device)
		else:
			out = model(images)
		
	print(f"\nout.shape = {out.shape}")
	print(f"\nInferenced using state_dict: '{fname}'")

	del images, labels

Can anyone give some pointers on how to get the compiled model onto the GPU for inference?

Hi @lostalot, so you can't really serialize a compiled model, because torch.compile is a JIT compiler.

If you want to avoid state_dict shenanigans, though, I'd recommend you do model.compile() instead of torch.compile(model).

Thanks @marksaroufim. Might there be a way to recover the trained model for inference on the GPU, say for the above example?

A .cuda() call should just work, but to be clear, what are you trying to do exactly? If you're trying to reduce compilation times, you should be using the compiler cache: Compile Time Caching in torch.compile — PyTorch Tutorials 2.9.0+cu128 documentation

A specific case: assume the state_dict of the model trained on the CIFAR-10 data in the code above was saved. I would like to use that saved model for inference on the GPU, rather than retrain the model using model.compile(). I don't know where or how to begin to do that.
As this is my first attempt at using torch.compile(), I am sure it is something basic. But at the moment I am not sure where to start. Any direction you could give would be a big help.

SOLVED. And this was a good education in torch.compile (thank you @marksaroufim).

Going through the references posted by @marksaroufim, it (eventually) became clear that an error message of "Given input size: (512x1x1). Calculated output size: (512x-5x-5)" was most likely a problem in the data shape. I went through the data loading code and discovered it. I am not sure when I messed it up, but I most likely fat-fingered the keyboard. Though I can't rule out a deep-seated desire for problems that I don't know about, but that is for a different kind of media.
Thank you for your patience and help.
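For anyone finding this thread later: a cheap shape guard in the data pipeline would have surfaced the bad batches immediately (a sketch; the expected 3x224x224 batch shape is an assumption based on the network's 7x7 final average pool, not something from the thread's data loader):

```python
# Guard against malformed batches before they reach the model (sketch;
# the expected 224x224 size is an assumption based on the final AvgPool2d(7)).
def check_batch_shape(shape, expected_hw=(224, 224)):
    n, c, h, w = shape
    if c != 3 or (h, w) != expected_hw:
        raise ValueError(
            f"unexpected batch shape {shape}; "
            f"expected (N, 3, {expected_hw[0]}, {expected_hw[1]})")
    return shape

print(check_batch_shape((128, 3, 224, 224)))
# (128, 3, 224, 224)
```

Calling it with `images.shape` right after the loader yields a batch turns a confusing compile-time pooling error into a direct message about the offending shape.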