`Net.load_state_dict()` fails on Windows but works on WSL

first_net = Net()

# Save the state of the model that was just constructed. The original line saved
# ``net.state_dict()``, a different model (most likely one that had already run
# forward passes). Its snntorch hidden-state buffers ``lif1.mem`` / ``lif2.mem``
# had batch-sized shapes ([16, 1000] / [16, 10]) instead of the fresh-model
# shape [1]. That caused the "size mismatch" error when loading into a new Net().
torch.save(first_net.state_dict(), 'blank_test.pth')

second_net = Net()

# NOTE(review): to load a *trained* checkpoint whose ``*.mem`` buffers have a
# different shape, one option is to drop those entries and load with
# ``strict=False``. Also confirm the snntorch version matches on both machines.
second_net.load_state_dict(torch.load('blank_test.pth'))

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[22], line 7
      3 torch.save(net.state_dict(), 'blank_test.pth')
      5 second_net = Net()
----> 7 second_net.load_state_dict(torch.load('blank_test.pth'))

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\torch\nn\modules\module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148         error_msgs.insert(
   2149             0, 'Missing key(s) in state_dict: {}. '.format(
   2150                 ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Net:
	size mismatch for lif1.mem: copying a param with shape torch.Size([16, 1000]) from checkpoint, the shape in current model is torch.Size([1]).
	size mismatch for lif2.mem: copying a param with shape torch.Size([16, 10]) from checkpoint, the shape in current model is torch.Size([1]).

I have been running this notebook in WSL on one system for quite some time and it works fine. However, when I deploy it on my larger Windows rig for model evaluation, `load_state_dict()` fails with the error above.

Did you check the shapes of `lif1.mem` and `lif2.mem`? These buffers seem to be causing the issue.

The shapes should be exactly the same, since I instantiated the same `Net()` class twice. My class definition is below:

# Main Setup

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

# --- Data loading configuration ---
batch_size = 128
data_path = '/tmp/data/mnist'

dtype = torch.float

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Preprocessing pipeline. Normalize((0,), (1,)) subtracts 0 and divides by 1,
# so it leaves the ToTensor output unchanged.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# MNIST train and test splits, fetched into data_path (download=True).
mnist_train = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)
mnist_test = datasets.MNIST(
    data_path,
    train=False,
    download=True,
    transform=transform,
)

# Both loaders shuffle and drop the trailing incomplete batch, so every batch
# has exactly batch_size samples.
# NOTE(review): drop_last=True on the test loader means the last partial batch
# of test images is never evaluated. Confirm that is intended.
train_loader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# --- Network architecture ---
num_inputs = 28 * 28  # one input per pixel of a 28x28 image
num_hidden = 1000
num_outputs = 10

# --- Temporal dynamics ---
num_steps = 25  # number of simulation time steps per forward pass
beta = 0.70  # passed as `beta` to each snn.Leaky layer

# Define Network
class Net(nn.Module):
    """Fully connected two-layer spiking network simulated for ``num_steps`` steps.

    Layers: ``fc1`` (num_inputs -> num_hidden) drives the ``lif1`` Leaky neurons,
    then ``fc2`` (num_hidden -> num_outputs) drives the ``lif2`` Leaky neurons.
    Layer sizes, ``beta`` and ``num_steps`` come from module-level constants.

    NOTE(review): the state dict of this model can contain ``lif1.mem`` /
    ``lif2.mem`` entries. Their shape apparently follows the last input the model
    processed: one reported checkpoint had [16, 1000] / [16, 10], while a fresh
    model had [1]. Whether snntorch persists these buffers appears to depend on
    the installed snntorch version. Confirm versions match across machines.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the state_dict key order: fc1, lif1, fc2, lif2.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network, feeding the same ``x`` at every time step.

        Args:
            x: input passed to ``fc1``; its last dimension must equal
                ``num_inputs`` (presumably already-flattened images; TODO confirm
                against the caller).

        Returns:
            ``(spikes, membranes)``: the output layer's spikes and membrane
            potentials, each stacked along a new leading time dimension of
            length ``num_steps``.
        """
        # Fresh hidden state at t=0 on every call.
        mem1, mem2 = self.lif1.init_leaky(), self.lif2.init_leaky()

        spikes, membranes = [], []
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(self.fc1(x), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spikes.append(spk2)
            membranes.append(mem2)

        return torch.stack(spikes, dim=0), torch.stack(membranes, dim=0)

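I understand that the shapes should be equal, but the error reports a mismatch. Could you print the shapes explicitly right before loading the state_dict to double-check? Also note that your snippet saves `net.state_dict()`, not `first_net.state_dict()`. If `net` is a model that has already run forward passes, its `lif1.mem`/`lif2.mem` buffers would have the batch-sized shapes shown in the error.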