Shown below is the error I receive when attempting to run this setup code block in Jupyter:
RuntimeError Traceback (most recent call last)
Cell In[5], line 136
131 data = data.to(device)
132 targets = targets.to(device)
--> 136 spk_rec, mem_rec = net(data.view(batch_size, -1))
137 # print(mem_rec.size())
138
139 # initialize the total loss value
140 loss_val = torch.zeros((1), dtype=dtype, device=device)
File c:\Users\brooks\AppData\Local\anaconda3\envs\snntorch\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File c:\Users\brooks\AppData\Local\anaconda3\envs\snntorch\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
...
--> 105 mem_shift = mem - self.threshold
106 reset = self.spike_grad(mem_shift).clone().detach()
108 return reset
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
This is my setup code block that triggers this error. I have attempted checking the variables that are tensors and manually setting them all to cuda and it has not worked.
# Main Setup
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import itertools
# dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'
dtype = torch.float
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
device = torch.device("cuda")
# Define a transform
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0,), (1,))])
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10
# Temporal Dynamics
num_steps = 25
beta = 0.70
V1 = 0.5 # shared recurrent connection
V2 = torch.rand(num_outputs) # unshared recurrent connections
# Define Network
class Net(nn.Module):
def __init__(self):
super().__init__()
# initialize layers
self.fc1 = nn.Linear(num_inputs, num_hidden)
# Default RLeaky Layer where recurrent connections
# are initialized using PyTorch defaults in nn.Linear.
self.lif1 = snn.RLeaky(beta=beta,
linear_features=num_hidden)
self.fc2 = nn.Linear(num_hidden, num_outputs)
# each neuron has a single connection back to itself
# where the output spike is scaled by V.
# For `all_to_all = False`, V can be shared between
# neurons (e.g., V1) or unique / unshared between
# neurons (e.g., V2).
# V is learnable by default.
self.lif2 = snn.RLeaky(beta=beta, all_to_all=False, V=V1)
def forward(self, x):
# Initialize hidden states at t=0
spk1, mem1 = self.lif1.init_rleaky()
spk2, mem2 = self.lif2.init_rleaky()
# Record output layer spikes and membrane
spk2_rec = []
mem2_rec = []
# time-loop
for step in range(num_steps):
cur1 = self.fc1(x)
spk1, mem1 = self.lif1(cur1, spk1, mem1)
cur2 = self.fc2(spk1)
spk2, mem2 = self.lif2(cur2, spk2, mem2)
spk2_rec.append(spk2)
mem2_rec.append(mem2)
# convert lists to tensors
spk2_rec = torch.stack(spk2_rec)
mem2_rec = torch.stack(mem2_rec)
return spk2_rec, mem2_rec
# Load the network onto CUDA if available
net = Net().to(device)
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target
def print_batch_accuracy(data, targets, train=False):
output, _ = net(data.view(batch_size, -1))
_, idx = output.sum(dim=0).max(1)
acc = np.mean((targets == idx).detach().cpu().numpy())
if train:
print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
else:
print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")
def train_printer(
data, targets, epoch,
counter, iter_counter,
loss_hist, test_loss_hist, test_data, test_targets):
print(f"Epoch {epoch}, Iteration {iter_counter}")
print(f"Train Set Loss: {loss_hist[counter]:.2f}")
print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
print_batch_accuracy(data, targets, train=True)
print_batch_accuracy(test_data, test_targets, train=False)
print("\n")
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)
spk_rec, mem_rec = net(data.view(batch_size, -1))
# print(mem_rec.size())
# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)
# sum loss at every step
for step in range(num_steps):
loss_val += loss(mem_rec[step], targets)
# print(f"Training loss: {loss_val.item():.3f}")
# print_batch_accuracy(data, targets, train=True)
# clear previously stored gradients
optimizer.zero_grad()
# calculate the gradients
loss_val.backward()
# weight update
optimizer.step()
# calculate new network outputs using the same data
spk_rec, mem_rec = net(data.view(batch_size, -1))
# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)
# sum loss at every step
for step in range(num_steps):
loss_val += loss(mem_rec[step], targets)
# print(f"Training loss: {loss_val.item():.3f}")
# print_batch_accuracy(data, targets, train=True)