Multi-GPU test-time dropout

Hi there,

Sorry if this has been asked before, but I could not find anything on this particular question:

I have a batch of data x (e.g. Nx1x28x28) that I want to evaluate on my model net using M forward passes, each with different random numbers (e.g. for dropout at test time), and then average the results. What is the most efficient way to parallelize these forward passes across multiple GPUs? I could do something like

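import torch
import torch.nn as nn

gpu_count = torch.cuda.device_count()
net = nn.DataParallel(net, list(range(gpu_count)))
n = x.size(0)                      # original batch size N
x = x.repeat(gpu_count, 1, 1, 1)   # gpu_count copies of the batch, stacked along dim 0
x = x.cuda()                       # the whole repeated tensor is copied to GPU 0 first

outputs = []
for _ in range(M):
    out = net(x)                   # each GPU runs a forward pass on its own copy of x
    outputs.append(out.view(gpu_count, n, *out.shape[1:]))
prediction = torch.cat(outputs).mean(dim=0)  # average over all M * gpu_count samples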

But this increases the memory demand on GPU 0 by a factor of gpu_count, because the repeated tensor x is first copied to one GPU and only then distributed in chunks to the others. Can this be done more efficiently?
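To see where the gpu_count factor comes from: x.repeat materializes gpu_count full copies of the batch on a single device, and DataParallel then splits its input along dim 0 into one chunk per GPU. A small CPU-only sketch (the sizes are made up, and torch.chunk stands in for DataParallel's scatter step, assuming an even split along dim 0):

```python
import torch

# Made-up sizes for illustration: a batch of N MNIST-like images and 3 GPUs.
N, gpu_count = 4, 3
x = torch.randn(N, 1, 28, 28)

# x.repeat tiles the batch gpu_count times along dim 0, all on one device.
x_rep = x.repeat(gpu_count, 1, 1, 1)
print(x_rep.shape)  # torch.Size([12, 1, 28, 28])

# Splitting along dim 0 into gpu_count equal chunks (what the scatter step does)
# hands every GPU one complete copy of the original batch.
chunks = x_rep.chunk(gpu_count, dim=0)
print(all(torch.equal(c, x) for c in chunks))  # True
```

So each GPU only needs x once, but the source device briefly holds gpu_count copies of it.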
Also: Do the individual instances on each GPU have different random seeds for the forward passes?
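On the seed question: each GPU has its own default CUDA generator, and torch.manual_seed (like torch.cuda.manual_seed_all) gives all of them the same seed. If that happens, replicas running the same operations on identical copies of x may draw identical dropout masks, which would make the per-GPU samples redundant. A hedged sketch that seeds each device differently (seed_each_gpu is a hypothetical helper, not a PyTorch API; torch.cuda.manual_seed seeds only the current device):

```python
import torch


def seed_each_gpu(base_seed):
    """Hypothetical helper: give every visible GPU's default generator its own seed.

    torch.manual_seed / torch.cuda.manual_seed_all use the same seed on all
    devices; seeding each device differently avoids identical dropout masks
    on replicas that receive identical inputs.
    """
    seeds = []
    for i in range(torch.cuda.device_count()):
        seed = base_seed + i
        with torch.cuda.device(i):
            torch.cuda.manual_seed(seed)  # seeds the current device only
        seeds.append(seed)
    return seeds  # empty list on a CPU-only machine
```

Calling it once after any global torch.manual_seed call should be enough, since each generator keeps advancing between forward passes.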

Thanks for your help!

EDIT: It is probably also inefficient to redistribute the repeated x for each of the M forward passes. It would be nice to keep x on each GPU for all M iterations until the next batch arrives.

I solved it myself by modifying the code of the DataParallel wrapper. The input tensor is now broadcast to the different GPUs instead of being scattered. The sampled outputs are accumulated on each GPU and gathered all at once when enough samples have been drawn. If you have memory constraints, you may want to change the code to gather the outputs after each sampling step:

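import math

import torch
import torch.nn as nn
import torch.cuda.comm as comm
from torch._utils import _get_device_index
from torch.nn.parallel import parallel_apply, replicate, scatter


class ParallelSamplingWrapper(nn.Module):
    """Like nn.DataParallel, but broadcasts the full input to every GPU and draws
    `samples` stochastic forward passes, stacked along a new leading dimension."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__()  # required before assigning submodules
        self.module = module
        self.dim = dim
        if not torch.cuda.is_available():
            self.device_ids = []
            return
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.device_ids = [_get_device_index(d, True) for d in device_ids]
        self.output_device = _get_device_index(output_device, True)
        # replicate() expects the original module to live on the first device
        self.module.cuda(self.device_ids[0])

    def forward(self, input, samples=None, **kwargs):
        # `samples` is consumed here so it is not passed on to the wrapped module
        if samples is None:
            samples = max(len(self.device_ids), 1)  # default: one sample per device
        if len(self.device_ids) <= 1:
            # CPU or single GPU: draw the samples one after another
            if self.device_ids:
                input = input.cuda(self.device_ids[0])
            return torch.stack([self.module(input, **kwargs) for _ in range(samples)])
        inputs, kwargs_tup = self.broadcast(input, kwargs, self.device_ids)
        replicas = self.replicate(self.module, self.device_ids)
        return self.sample(replicas, inputs, kwargs_tup, samples)

    def sample(self, replicas, inputs, kwargs_tup, samples):
        outputs = []
        for _ in range(math.ceil(samples / len(self.device_ids))):
            out = self.parallel_apply(replicas, inputs, kwargs_tup)
            # add a leading sample dimension; outputs stay on their own GPUs
            outputs += [o.unsqueeze(0) for o in out]  # gather here if you have memory constraints
        # concatenate all samples on the output device and drop the surplus ones
        return self.gather(outputs, self.output_device)[:samples]

    def broadcast(self, input, kwargs, device_ids):
        # send a full copy of the input to every device instead of scattering chunks
        inputs = comm.broadcast(input, device_ids)
        # parallel_apply expects exactly one kwargs dict per device
        kwargs_tup = scatter(kwargs, device_ids) if kwargs else [{} for _ in device_ids]
        return inputs, kwargs_tup

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def parallel_apply(self, replicas, inputs, kwargs_tup):
        return parallel_apply(replicas, inputs, kwargs_tup, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return comm.gather(outputs, dim=self.dim, destination=output_device)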
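For the dropout layers to sample at test time, they have to be in training mode while the rest of the network (e.g. BatchNorm) stays in eval mode. Calling a helper like the one below on net before wrapping it should carry over to the replicas, since replicate copies each module's attributes, including its training flag. If only the mean is needed, a running sum is a lighter alternative to gathering every sample. A single-device sketch under those assumptions (enable_mc_dropout and mc_dropout_mean are hypothetical names, not PyTorch APIs):

```python
import torch
import torch.nn as nn

# Dropout variants whose train/eval flag decides whether masks are sampled
_DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


def enable_mc_dropout(model):
    """Hypothetical helper: eval mode for everything except dropout layers."""
    model.eval()  # e.g. BatchNorm switches to its running statistics
    for m in model.modules():
        if isinstance(m, _DROPOUT_TYPES):
            m.train()  # keep drawing random dropout masks at test time
    return model


@torch.no_grad()
def mc_dropout_mean(model, x, num_samples):
    """Hypothetical helper: mean of num_samples stochastic forward passes,
    kept as a running sum so only one output-sized buffer is held."""
    total = None
    for _ in range(num_samples):
        out = model(x)
        total = out if total is None else total + out
    return total / num_samples


# Usage on the CPU with a toy network (made-up sizes)
torch.manual_seed(0)
net = enable_mc_dropout(nn.Sequential(nn.Flatten(), nn.Dropout(0.5), nn.Linear(784, 10)))
x = torch.randn(4, 1, 28, 28)
mean = mc_dropout_mean(net, x, num_samples=20)
print(mean.shape)  # torch.Size([4, 10])
```

With the wrapper, the equivalent would be to average its stacked output over the first dimension instead of keeping a running sum.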