Hi, I’m trying to implement a customized DataParallel where I manually reduce gradients from the replicas on multiple GPUs. The reason I’m doing this is that I want each GPU to accumulate gradients locally for several iterations before doing a single gradient-reduce operation across the GPUs, which should cut the communication overhead. With 2 GPUs it seems to work. However, with 4 GPUs, after some iterations the program simply crashes without any error message; I suspect some lower-level C code is crashing. Do you have any idea why? Here is my code:
```python
import torch.cuda.comm as comm
from torch.nn.parallel import DataParallel, replicate


class DataParallelAccumulation(DataParallel):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__(module, device_ids=device_ids,
                         output_device=output_device, dim=dim)
        if len(self.device_ids) > 1:
            # Keep the replicas alive across iterations so each GPU
            # accumulates gradients locally.
            self.replicas = self.replicate(self.module, self.device_ids, detach=True)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # Reuse the cached replicas instead of re-replicating every forward pass.
        outputs = self.parallel_apply(self.replicas[:len(inputs)], inputs, kwargs)
        return outputs

    def reduce_grads(self):
        if len(self.device_ids) > 1:
            for parameters in zip(self.module.parameters(),
                                  *[r.parameters() for r in self.replicas]):
                destination_device = parameters[0].get_device()
                # Sum the accumulated gradients from all replicas onto the
                # master parameter's device.
                parameters[0].grad = comm.reduce_add(
                    [p.grad for p in parameters[1:]],
                    destination=destination_device)

    def synchronize(self):
        if len(self.device_ids) > 1:
            # Refresh the replicas from the (just-updated) master module.
            self.replicas = self.replicate(self.module, self.device_ids, detach=True)

    def replicate(self, module, device_ids, detach=False):
        return replicate(module, device_ids, detach=detach)
```
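For context, here is roughly how I drive it in the training loop (a minimal sketch; `net`, `loader`, `criterion`, and `ACC_STEPS` are placeholders I've made up for this post):

```python
import torch
from torch.nn.parallel import scatter

net = MyNet().cuda()                      # placeholder model
model = DataParallelAccumulation(net, device_ids=[0, 1, 2, 3])
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
ACC_STEPS = 4                             # local accumulation steps per reduce

for step, (x, y) in enumerate(loader):
    outputs = model(x.cuda())             # list of per-device outputs (not gathered)
    targets = scatter(y.cuda(), model.device_ids)
    for out, tgt in zip(outputs, targets):
        # Each backward accumulates gradients on that replica's own device.
        criterion(out, tgt).backward()
    if (step + 1) % ACC_STEPS == 0:
        model.reduce_grads()              # single cross-GPU reduce
        optimizer.step()
        optimizer.zero_grad()
        model.synchronize()               # re-replicate the updated master module
```

The idea is that `backward()` only touches the local replica's gradients, so the cross-GPU traffic happens once every `ACC_STEPS` steps rather than every iteration.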