Multiple-GPU Error - Data Parallel

Thanks @ptrblck. This reproduces the error for me:

import os
from pytorchcv.model_provider import get_model as ptcv_get_model
import torch
import types
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

cifar_loc = '/disk/scratch/s1874193/datasets/cifar'

net = ptcv_get_model("densenet40_k12_cifar10", root = '/home/s1874193/Distillation/xdistill/pre_trained_models', pretrained=True)
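# Replacement forward that also returns the intermediate activations
def my_forward(self, x):
    activations = []
    # Run each feature block in turn and record its output
    for module in self.features._modules.values():
        x = module(x)
        activations.append(x)
    x = x.view(x.size(0), -1)  # flatten before the classifier
    x = self.output(x)
    return x, activations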

net.forward = types.MethodType(my_forward, net)

if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net, device_ids=[0,1,2,3])
net.to(device)
net.eval()

x = torch.randn(4, 3, 32, 32)
out, act = net(x)


I’m not doing anything with the forward() method other than what you can see here. I think it’s somehow related to how I’m using CIFAR, as I didn’t get the error when just doing

x = torch.randn(1, 3, 32, 32)
out, activations = net(x)
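One possible explanation, offered as a guess rather than a confirmed diagnosis: nn.DataParallel splits the batch along dimension 0, so a batch of 1 only ever reaches a single GPU, while a batch of 4 is spread over several replicas. If each replica is a shallow copy of the module, a forward attached with types.MethodType is stored on the instance and stays bound to the original net, so every replica would run the original module's weights on cuda:0. The sketch below shows only that binding behaviour in plain Python, with no torch involved. `Net`, `PatchedNet`, and the `device` string are hypothetical stand-ins, and `copy.copy` is an assumption standing in for whatever DataParallel does internally when it replicates a module.

```python
import copy
import types


class Net:
    """Hypothetical stand-in for an nn.Module; `device` is just a label."""

    def __init__(self, device):
        self.device = device

    def forward(self, x):
        return self.device


def my_forward(self, x):
    # Returns the device label of whichever object `self` really is.
    return self.device


# Instance-level patch, as in the post: the bound method lives in net.__dict__.
net = Net("cuda:0")
net.forward = types.MethodType(my_forward, net)

# Shallow copy, assumed here to mimic a per-GPU replica.
replica = copy.copy(net)
replica.device = "cuda:1"

# The copied attribute is still bound to the original object.
bound_to_original = replica.forward.__self__ is net
device_seen_by_patch = replica.forward(None)


# Class-level alternative: the method is looked up on the type,
# so it binds to whichever instance it is called on.
class PatchedNet(Net):
    forward = my_forward


fixed = PatchedNet("cuda:0")
fixed_replica = copy.copy(fixed)
fixed_replica.device = "cuda:1"
device_seen_by_class_method = fixed_replica.forward(None)

print(bound_to_original, device_seen_by_patch, device_seen_by_class_method)
```

If this is what is happening, defining the custom forward on a subclass (or assigning it on the class rather than the instance) might avoid the problem, since each replica would then call forward with itself as `self`.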

Thanks!