Is there any way to separate the replication of a DataParallel module from the collection of its gradients?

I update the parameters of my model only after several forward and backward calls. This means DataParallel does not need to replicate self.module on every forward pass. It only needs to accumulate the gradients from the other GPUs when I am about to update the parameters by calling the optimizer's step().
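
For reference, the scheme described above (several forward/backward calls, then one optimizer step) is plain gradient accumulation, and a stock nn.DataParallel already supports it: every backward() reduces the replicas' gradients and adds them into the parameters of the wrapped module. That per-backward reduction is exactly the cost being discussed here. A minimal sketch, assuming a toy nn.Linear model and a sum-of-squares loss (both made up for illustration), that checks gradients accumulated over two micro-batches match a single full-batch backward:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Linear(4, 1).to(device)  # toy model, for illustration only
# On a CPU-only machine DataParallel simply calls net; with GPUs it replicates per forward.
model = nn.DataParallel(net)

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)


def loss_fn(out, target):
    # Sum (not mean) so that the micro-batch losses add up to the full-batch loss.
    return ((out - target) ** 2).sum()


# Several forward/backward calls: each backward() adds into net.weight.grad.
model.zero_grad()
for xb, yb in zip(x.chunk(2), y.chunk(2)):
    loss_fn(model(xb), yb).backward()
accumulated = net.weight.grad.clone()

# Reference: one forward/backward over the whole batch.
model.zero_grad()
loss_fn(model(x), y).backward()
full_batch = net.weight.grad.clone()

print(torch.allclose(accumulated, full_batch))  # True
```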

So, is there any way to separate replication from gradient collection in DataParallel? Replicating on every forward pass greatly hurts the performance of my implementation.
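
As far as I know, DataParallel has no option for this: replication is recorded in the autograd graph, and the reduction of gradients back to the source device happens in the backward of that replication step, so the two are coupled. One way to decouple them is to bypass DataParallel and manage persistent per-device copies by hand. This is only a sketch under stated assumptions: make_replicas, reduce_grads_to_master and sync_replicas are hypothetical helpers (not a PyTorch API), the per-device forwards run sequentially rather than in threads, and buffers such as BatchNorm running statistics are not synchronized.

```python
import copy

import torch
import torch.nn as nn


def make_replicas(master, devices):
    # Hypothetical helper: one independent copy per device with its own leaf
    # parameters, created once instead of on every forward pass.
    return [copy.deepcopy(master).to(d) for d in devices]


def reduce_grads_to_master(master, replicas):
    # Hypothetical helper: sum the gradients each replica accumulated locally
    # into the master's .grad, then clear the replicas' gradients.
    for p_master, *p_reps in zip(master.parameters(), *(r.parameters() for r in replicas)):
        grads = [p.grad.to(p_master.device) for p in p_reps if p.grad is not None]
        if grads:
            p_master.grad = torch.stack(grads).sum(dim=0)
    for r in replicas:
        r.zero_grad()


@torch.no_grad()
def sync_replicas(master, replicas):
    # Hypothetical helper: copy the updated master weights into every replica in place.
    for r in replicas:
        for p_rep, p_master in zip(r.parameters(), master.parameters()):
            p_rep.copy_(p_master)


# Falls back to two CPU "devices" so the sketch also runs without GPUs.
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
master = nn.Linear(4, 1).to(devices[0])  # toy model, for illustration only
replicas = make_replicas(master, devices)
optimizer = torch.optim.SGD(master.parameters(), lr=0.1)

for _ in range(3):  # several forward/backward calls with no cross-device gradient traffic
    x = torch.randn(8, 4)
    for rep, dev, xb in zip(replicas, devices, x.chunk(len(devices))):
        rep(xb.to(dev)).sum().backward()  # gradients stay on each replica's device

reduce_grads_to_master(master, replicas)  # collect gradients once, right before the update
optimizer.step()
sync_replicas(master, replicas)  # push the new weights back to the replicas
```

The trade-off is that weights are only re-broadcast once per update instead of once per forward pass, at the price of writing the reduction and synchronization yourself.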

I tried to do this with the following code:

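```python
from torch.nn import DataParallel


class DataParallelModel(DataParallel):
    """DataParallel that can reuse cached replicas instead of re-creating them on every forward."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0, host_control=True):
        super(DataParallelModel, self).__init__(module, device_ids, output_device, dim)
        if host_control:
            # Replicate once up front and reuse the copies in later forward passes.
            self.nets = self.update_replicates()
        else:
            # Standard DataParallel behaviour: replicate on every forward pass.
            self.nets = None

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.nets is None:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        else:
            # Only use as many cached replicas as there are input chunks.
            replicas = self.nets[:len(inputs)]
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def update_replicates(self):
        # Re-create the per-GPU copies from the current parameters of self.module.
        return self.replicate(self.module, self.device_ids)
```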

But autograd threw an assertion failure.

If you split your input across several GPUs, you have to replicate the module on each GPU.
I wrote up some details here; maybe they can help you.

Thank you.
I have read your blog, and I want to make the first approach described there run more efficiently on multiple GPUs. I have fixed it with this for now.