import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data.dataloader import DataLoader
import copy
devices = [0,1]
batch_size = 64
sim_type = 'snn'
timesteps = 100
device = 'cuda:' + str(devices[0])
def test(model, epoch, test_loader):
print('Testing model..')
total = 0
with torch.no_grad():
model.eval()
for i in range(100):
# looping over dataset (start)
for batch_idx, (data, target) in enumerate(test_loader):
if sim_type == 'snn':
data = data.repeat(timesteps, 1, 1, 1, 1)
data = data.swapaxes(0, 1)
B,T,C,H,W = data.size()
else:
B,C,H,W = data.size()
data, target = data.to(device), target.to(device)
if batch_idx == 0:
print('data: {}'.format(data.size()))
output = model(data)
loss = F.cross_entropy(output, target)
pred = output.max(1,keepdim=True)[1]
correct = pred.eq(target.data.view_as(pred)).cpu().sum()
total += B
top1 = correct / total
print('Batch: {}, Accuracy: {}/{}({:.4f}), Loss: {:.4f}'
.format( batch_idx, correct.item(), B, top1, loss,))
# looping over dataset (end)
def load_cifar100():
normalize = transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])
testset = datasets.CIFAR100(root='./Datasets/cifar_data', \
train=False, download=False, transform=transform)
return testset
class ANN(nn.Module):
def __init__(self):
super().__init__()
self.ann_dict = nn.ModuleDict({
'conv1': nn.Conv2d(3, 64, 3, padding=1, stride=1, bias=False),
'relu1': nn.ReLU(inplace=True),
'avg1' : nn.AvgPool2d(2, 2),
'conv2': nn.Conv2d(64, 128, 3, padding=1, stride=1, bias=False),
'relu2': nn.ReLU(inplace=True),
'avg2' : nn.AvgPool2d(2, 2),
'flat3': nn.Flatten(),
'lin4' : nn.Linear(8192, 100, bias=False),
})
def forward(self, x):
out_prev = x
for k, v in self.ann_dict.items():
out_prev = self.ann_dict[k](out_prev)
return out_prev
class SNN(nn.Module):
def __init__(self, ann, device, in_sz):
super(SNN, self).__init__()
num_relus = 0
for k, v in ann.named_modules():
if isinstance(v, nn.ReLU):
num_relus += 1
self.device = device
thresholds = torch.ones(num_relus, device=device)
self.snn_cell = SNNCell(ann, thresholds, device=device, in_sz=in_sz)
def reset_vmems(self):
self.snn_cell.reset_vmems()
def get_vmems(self):
return self.snn_cell.get_vmems()
def forward(self, x):
self.reset_vmems()
vmems = self.get_vmems()
return self.snn_cell(x, *vmems)
class SNNCell(nn.Module):
def __init__(self, ann, thresholds, device, in_sz):
super(SNNCell, self).__init__()
self.device = device
self.in_sz = in_sz
self.snn_dict = self.convert_to_snn(ann, thresholds)
def convert_to_snn(self, ann, thr):
snn_dict = copy.deepcopy(ann.ann_dict)
relu_idx = 0
in_sz = self.in_sz
for k, v in snn_dict.items():
out_sz = v.cpu()(torch.zeros(in_sz)).size()
if isinstance(v, nn.ReLU):
snn_dict[k] = SpikeRelu(thr[relu_idx], sz=in_sz, device=self.device,)
relu_idx += 1
in_sz = out_sz
return snn_dict
def reset_vmems(self):
for k, v in self.snn_dict.items():
if isinstance(v, SpikeRelu):
self.snn_dict[k].reset_vmem()
def get_vmems(self):
vmems = []
for k, v in self.snn_dict.items():
if isinstance(v, SpikeRelu):
vmems.append(v.vmem.clone())
return vmems
def forward(self, x, *mem):
T = x.size(1)
L = len(mem)
mem_in = list(mem[:])
mem_out = [None] * L
for t in range(T):
spk_layer_idx = 0
outprev = x[:,t]
for name, layer in self.snn_dict.items():
if isinstance(layer, SpikeRelu):
outprev, mem_out[spk_layer_idx] = layer(outprev, mem_in[spk_layer_idx])
spk_layer_idx += 1
else:
outprev = layer(outprev)
mem_in = mem_out[:]
return outprev
class SpikeRelu(nn.Module):
def __init__(self, vth, sz, device='cuda:0'):
super(SpikeRelu, self).__init__()
self.threshold = vth
self.sz = sz
self.device = device
self.reset_vmem()
def reset_vmem(self):
self.vmem = torch.zeros(*self.sz, device=self.device, requires_grad=True)
def forward(self, x, mem_in):
mem_thr = (mem_in / self.threshold) - 1.0
op_spk = torch.zeros_like(mem_thr)
op_spk[mem_thr > 0] = 1.0
rst = self.threshold * (mem_thr > 0).float()
mem_out = mem_in + x - rst
return op_spk, mem_out
def __repr__(self):
return 'SpikeRelu(v_th : {:.3f}, vmem : {})'.format(self.threshold, torch.sum(self.vmem))
Hello everyone,
I am trying to apply torch.nn.DataParallel
to a custom module (class SNN
) whose simplified definition is provided above. I perform inference on this network using the following code:
test_set = load_cifar100()
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
model = ANN()
print('devices ', devices, 'device ', device)
in_sz = (1,3,32,32)
if sim_type == 'snn':
model = SNN(model, device, in_sz, )
model = torch.nn.DataParallel(model, device_ids=devices)
model = model.to(device)
print(model)
test(model, 0, test_loader)
This keeps giving the following error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
I believe the error is due to the way I am initializing my vmem
in class SpikeRelu
but I am not sure how to fix it. Please help!!