Error while applying torch.nn.DataParallel to custom network

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from import DataLoader

 import copy

 devices    = [0,1]
 batch_size = 64
 sim_type   = 'snn'
 timesteps  = 100

 device = 'cuda:' + str(devices[0])

 def test(model, epoch, test_loader):
     print('Testing model..')

     total = 0
     with torch.no_grad():

         for i in range(100):
             # looping over dataset (start)
             for batch_idx, (data, target) in enumerate(test_loader):

                 if sim_type == 'snn':
                     data = data.repeat(timesteps, 1, 1, 1, 1)
                     data = data.swapaxes(0, 1)
                     B,T,C,H,W = data.size()
                     B,C,H,W   = data.size()

                 data, target =,
                 if batch_idx == 0:
                     print('data: {}'.format(data.size()))

                 output  = model(data)
                 loss    = F.cross_entropy(output, target)
                 pred    = output.max(1,keepdim=True)[1]
                 correct = pred.eq(
                 total   += B
                 top1    = correct / total

                 print('Batch: {}, Accuracy: {}/{}({:.4f}), Loss: {:.4f}'
                     .format( batch_idx, correct.item(), B, top1, loss,))
             # looping over dataset (end)

 def load_cifar100():
     normalize = transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
     transform = transforms.Compose([transforms.ToTensor(), normalize])
     testset   = datasets.CIFAR100(root='./Datasets/cifar_data', \
             train=False, download=False, transform=transform)
     return testset

 class ANN(nn.Module):
     def __init__(self):

         self.ann_dict = nn.ModuleDict({
             'conv1': nn.Conv2d(3, 64, 3, padding=1, stride=1, bias=False),
             'relu1': nn.ReLU(inplace=True),
             'avg1' : nn.AvgPool2d(2, 2),

             'conv2': nn.Conv2d(64, 128, 3, padding=1, stride=1, bias=False),
             'relu2': nn.ReLU(inplace=True),
             'avg2' : nn.AvgPool2d(2, 2),

             'flat3': nn.Flatten(),
             'lin4' : nn.Linear(8192, 100, bias=False),

     def forward(self, x):
         out_prev = x
         for k, v in self.ann_dict.items():
             out_prev = self.ann_dict[k](out_prev)
         return out_prev

 class SNN(nn.Module):
     def __init__(self, ann, device, in_sz):
         super(SNN, self).__init__()

         num_relus = 0
         for k, v in ann.named_modules():
             if isinstance(v, nn.ReLU):
                 num_relus += 1

         self.device   = device
         thresholds    = torch.ones(num_relus, device=device)
         self.snn_cell = SNNCell(ann, thresholds, device=device, in_sz=in_sz)

     def reset_vmems(self):

     def get_vmems(self):
         return self.snn_cell.get_vmems()

     def forward(self, x):
         vmems = self.get_vmems()
         return self.snn_cell(x, *vmems)

 class SNNCell(nn.Module):
     def __init__(self, ann, thresholds, device, in_sz):
         super(SNNCell, self).__init__()

         self.device   = device
         self.in_sz    = in_sz
         self.snn_dict = self.convert_to_snn(ann, thresholds)

     def convert_to_snn(self, ann, thr):
         snn_dict   = copy.deepcopy(ann.ann_dict)
         relu_idx   = 0
         in_sz      = self.in_sz
         for k, v in snn_dict.items():
             out_sz = v.cpu()(torch.zeros(in_sz)).size()
             if isinstance(v, nn.ReLU):
                 snn_dict[k] = SpikeRelu(thr[relu_idx], sz=in_sz, device=self.device,)
                 relu_idx += 1
             in_sz = out_sz
         return snn_dict

     def reset_vmems(self):
         for k, v in self.snn_dict.items():
             if isinstance(v, SpikeRelu):
     def get_vmems(self):
         vmems = []
         for k, v in self.snn_dict.items():
             if isinstance(v, SpikeRelu):
         return vmems

     def forward(self, x, *mem):
         T = x.size(1)

         L       = len(mem)
         mem_in  = list(mem[:])
         mem_out = [None] * L

         for t in range(T):
             spk_layer_idx  = 0
             outprev        = x[:,t]
             for name, layer in self.snn_dict.items():
                 if isinstance(layer, SpikeRelu):
                     outprev, mem_out[spk_layer_idx] = layer(outprev, mem_in[spk_layer_idx])
                     spk_layer_idx += 1
                     outprev = layer(outprev)
             mem_in = mem_out[:]
         return outprev

 class SpikeRelu(nn.Module):
     def __init__(self, vth, sz, device='cuda:0'):
          super(SpikeRelu, self).__init__()

          self.threshold = vth
        = sz
          self.device    = device

     def reset_vmem(self):
         self.vmem = torch.zeros(*, device=self.device, requires_grad=True)

     def forward(self, x, mem_in):
         mem_thr = (mem_in / self.threshold) - 1.0
         op_spk  = torch.zeros_like(mem_thr)
         op_spk[mem_thr > 0] = 1.0
         rst     = self.threshold * (mem_thr > 0).float()
         mem_out = mem_in + x - rst
         return op_spk, mem_out

     def __repr__(self):
         return 'SpikeRelu(v_th : {:.3f}, vmem : {})'.format(self.threshold, torch.sum(self.vmem))

Hello everyone,
I am trying to apply torch.nn.DataParallel to a custom module (class SNN) whose simplified definition is provided above. I perform inference on this network using the following code:

 test_set    = load_cifar100()
 test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
 model = ANN()

 print('devices ', devices, 'device ', device)
 in_sz = (1,3,32,32)
 if sim_type == 'snn':
     model = SNN(model, device, in_sz, )

 model = torch.nn.DataParallel(model, device_ids=devices)
 model =

 test(model, 0, test_loader)

This keeps giving the following error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I believe the error is due to the way I am initializing my vmem in class SpikeRelu but I am not sure how to fix it. Please help!!

Don’t pass the device argument to SkipeRelu, but use the .device attribute from e.g. the incoming activation tensor, if needed.
E.g. you could use:

     def reset_vmem(self, device):
         self.vmem = torch.zeros(*, device=device, requires_grad=True)

and call it via self.reset_vmem(x.device) where x would be the input tensor.

Thanks, @ptrblck, your solution worked perfectly!