RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Shown below is the error I receive when attempting to run this setup code block in Jupyter:


RuntimeError                              Traceback (most recent call last)
Cell In[5], line 136
    131 data = data.to(device)
    132 targets = targets.to(device)
--> 136 spk_rec, mem_rec = net(data.view(batch_size, -1))
    137 # print(mem_rec.size())
    138 
    139 # initialize the total loss value
    140 loss_val = torch.zeros((1), dtype=dtype, device=device)

File c:\Users\brooks\AppData\Local\anaconda3\envs\snntorch\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File c:\Users\brooks\AppData\Local\anaconda3\envs\snntorch\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
...
--> 105     mem_shift = mem - self.threshold
    106     reset = self.spike_grad(mem_shift).clone().detach()
    108     return reset

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Below is the setup code block that triggers this error. I have checked every variable that is a tensor and manually moved them all to CUDA, but the error persists.

# Main Setup

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

# dataloader arguments
batch_size = 128
data_path = '/tmp/data/mnist'

dtype = torch.float

# Prefer CUDA, then Apple MPS, then CPU, instead of hard-coding "cuda":
# a hard-coded CUDA device fails on machines without an NVIDIA GPU, and every
# later `.to(device)` call in this notebook picks this value up automatically.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Image preprocessing. Normalize((0,), (1,)) computes (x - 0) / 1, i.e. it
# leaves the ToTensor() output unchanged. Resize and Grayscale only have an
# effect if the source images are not already 28x28 single-channel.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders. drop_last=True discards the final incomplete batch,
# so every batch yielded has exactly `batch_size` samples.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Network Architecture
num_inputs = 28 * 28   # one flattened 28x28 image
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25         # simulation time steps per forward pass
beta = 0.70            # decay factor passed to both RLeaky layers

V1 = 0.5                      # shared recurrent connection (used by Net.lif2)
V2 = torch.rand(num_outputs)  # unshared (per-neuron) alternative; not used by Net as written

# Define Network
class Net(nn.Module):
    """Two-layer recurrent spiking network unrolled over ``num_steps`` time steps.

    Layout (sizes come from the module-level constants):
    ``fc1`` (num_inputs -> num_hidden) -> ``lif1`` (RLeaky, recurrent
    connections via an internal nn.Linear) -> ``fc2`` (num_hidden ->
    num_outputs) -> ``lif2`` (RLeaky, one-to-one recurrent connection scaled
    by ``V1``).
    """

    def __init__(self):
        super().__init__()

        # Input projection.
        self.fc1 = nn.Linear(num_inputs, num_hidden)

        # Default RLeaky layer where recurrent connections
        # are initialized using PyTorch defaults in nn.Linear.
        self.lif1 = snn.RLeaky(beta=beta, linear_features=num_hidden)

        self.fc2 = nn.Linear(num_hidden, num_outputs)

        # With all_to_all=False each neuron has a single connection back to
        # itself, where the output spike is scaled by V. V can be shared
        # between neurons (e.g. V1) or unique per neuron (e.g. V2).
        # V is learnable by default.
        self.lif2 = snn.RLeaky(beta=beta, all_to_all=False, V=V1)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: flattened input batch. Callers pass ``data.view(batch, -1)``,
                so presumably shape (batch, num_inputs). TODO confirm.

        Returns:
            ``(spk2_rec, mem2_rec)``: the output layer's spikes and membrane
            potentials at every step, stacked along a new leading dimension
            of length ``num_steps``.
        """
        # Initialize hidden states at t=0.
        # NOTE(review): these states are created by snnTorch, not by this
        # module. A "cuda:0 and cpu" mismatch inside snnTorch's reset
        # computation (`mem - self.threshold`) was reported with snnTorch
        # versions newer than 0.7.0. Confirm the states end up on x.device
        # when upgrading.
        spk1, mem1 = self.lif1.init_rleaky()
        spk2, mem2 = self.lif2.init_rleaky()

        # x is the same at every time step, so its projection is
        # loop-invariant: compute it once instead of num_steps times.
        cur1 = self.fc1(x)

        # Record output layer spikes and membrane.
        spk2_rec = []
        mem2_rec = []

        # Time loop: hidden states are threaded from one step to the next.
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, spk1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, spk2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Build the network, then move all of its registered parameters and buffers
# onto `device`. nn.Module.to() works in place and returns the module itself.
net = Net()
net.to(device)

def print_batch_accuracy(data, targets, train=False):
    """Print (and return) the accuracy of the global ``net`` on one minibatch.

    The data is passed through the network and the spikes of each output
    neuron are summed over time. The neuron with the highest spike count is
    the prediction, which is compared with the target.

    Args:
        data: input batch; reshaped to (batch, -1) before the forward pass.
        targets: class labels, compared element-wise with the predicted
            indices (so presumably on the same device as ``net``'s output).
        train: selects the "Train" or "Test" wording of the printed message.

    Returns:
        The fraction of correctly classified samples in the minibatch.
    """
    # Evaluation only: don't build an autograd graph for the whole
    # num_steps-long unroll. The gradients would never be used.
    with torch.no_grad():
        # Use the actual batch size rather than the global `batch_size`, so
        # a batch of any size reshapes correctly.
        output, _ = net(data.view(data.size(0), -1))

    # output is stacked over time on dim 0 (see Net.forward). Sum over time,
    # then take the index of the most active output neuron per sample.
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

    return acc

def train_printer(
        data, targets, epoch,
        counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets):
    """Print a training progress report for one iteration.

    Reports the epoch and iteration, the train and test losses recorded at
    index ``counter``, and single-minibatch accuracy on the given train and
    test batches. The report ends with a blank separator.
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")

    # Training minibatch first, then the held-out minibatch.
    batches = (
        (data, targets, True),
        (test_data, test_targets, False),
    )
    for inputs, labels, is_train in batches:
        print_batch_accuracy(inputs, labels, train=is_train)

    print("\n")

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))


def _membrane_loss(mem_rec, targets):
    """Sum the cross-entropy of the output membrane potential over all steps.

    Args:
        mem_rec: output-layer membrane potentials stacked over time on dim 0,
            as returned by ``Net.forward``.
        targets: class labels for the batch.

    Returns:
        A 1-element tensor holding the summed loss (still attached to the
        autograd graph of ``mem_rec``).
    """
    total = torch.zeros(1, dtype=dtype, device=device)
    for step in range(num_steps):
        total += loss(mem_rec[step], targets)
    return total


# One optimisation step on a single training minibatch.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

# Flatten each image; data.size(0) is the actual number of samples.
spk_rec, mem_rec = net(data.view(data.size(0), -1))
loss_val = _membrane_loss(mem_rec, targets)

optimizer.zero_grad()   # clear previously stored gradients
loss_val.backward()     # calculate the gradients
optimizer.step()        # weight update

# Recompute the outputs and loss on the same data to see the effect of the update.
spk_rec, mem_rec = net(data.view(data.size(0), -1))
loss_val = _membrane_loss(mem_rec, targets)

# print(f"Training loss: {loss_val.item():.3f}")
# print_batch_accuracy(data, targets, train=True)

Check what `self.threshold` is. If it is a tensor, make sure it has been moved to the same device as the model inputs, and as `mem`, the other operand in `mem - self.threshold`.

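Found the issue: I was running the newest version of snnTorch, which broke this code. Reverting a couple of versions to the last one that ran on my other machine, snnTorch 0.7.0, fixed it. I'll raise an issue on the snnTorch repo. Below is my pip requirements file with the pinned version of every package, in case anyone wants to reproduce the environment. Note that it was frozen on an Ubuntu machine. Entries such as `python-apt`, `dbus-python`, `systemd-python` and `ubuntu-advantage-tools` are Ubuntu system packages and will not install on other platforms such as Windows. The pins that matter here are `snntorch==0.7.0`, `torch==2.2.0+cu118` and `torchvision==0.17.0+cu118`.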

anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
blinker==1.4
Brian2==2.6.0
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
comm==0.2.1
command-not-found==0.3
contourpy==1.2.0
cryptography==3.4.8
cycler==0.12.1
Cython==3.0.9
dbus-python==1.2.18
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
distro==1.7.0
distro-info==1.1+ubuntu0.2
exceptiongroup==1.2.0
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.9.0
fonttools==4.48.1
fqdn==1.5.1
fsspec==2023.4.0
h11==0.14.0
h5py==3.10.0
httpcore==1.0.2
httplib2==0.20.2
httpx==0.26.0
idna==3.6
importlib-metadata==4.6.4
ipykernel==6.29.2
ipython==8.21.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.3
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.2
jupyter_client==8.6.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.2
jupyterlab==4.1.1
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
jupyterlab_widgets==3.0.10
keyring==23.5.0
kiwisolver==1.4.5
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
MarkupSafe==2.1.5
matplotlib==3.8.2
matplotlib-inline==0.1.6
mistune==3.0.2
more-itertools==8.10.0
mpmath==1.3.0
nbclient==0.9.0
nbconvert==7.16.0
nbformat==5.9.2
nest-asyncio==1.6.0
netifaces==0.11.0
networkx==3.2.1
nir==1.0.1
nirtorch==1.0
notebook==7.1.0
notebook_shim==0.2.3
numpy==1.26.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
oauthlib==3.2.0
overrides==7.7.0
packaging==23.2
pandas==2.2.1
pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.2.0
prometheus-client==0.19.0
prompt-toolkit==3.0.43
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.17.2
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu3
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2024.1
PyYAML==5.4.1
pyzmq==25.1.2
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.18.0
SecretStorage==3.3.1
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
snntorch==0.7.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
systemd-python==234
terminado==0.18.0
tinycss2==1.2.1
tomli==2.0.1
torch==2.2.0+cu118
torchaudio==2.2.0+cu118
torchvision==0.17.0+cu118
tornado==6.4
traitlets==5.14.1
triton==2.2.0
types-python-dateutil==2.8.19.20240106
typing_extensions==4.9.0
tzdata==2024.1
ubuntu-advantage-tools==8001
ufw==0.36.1
unattended-upgrades==0.1
uri-template==1.3.0
urllib3==2.2.0
wadllib==1.3.6
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
widgetsnbextension==4.0.10
zipp==1.0.0