The forward function of `torch.nn.parallel.DataParallel` calls its `scatter` method to split the input data across all of the devices:
```python
class DataParallel(Module):
    def forward(self, *inputs, **kwargs):
        ...
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        ...
```
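For reference, my test setup looks roughly like this (a hypothetical two-GPU sketch; the model and sizes are just placeholders):

```python
import torch
import torch.nn as nn

# Hypothetical minimal setup, assuming two visible CUDA devices.
model = nn.Linear(4, 2).cuda()
dp_model = nn.DataParallel(model, device_ids=[0, 1])

x = torch.randn(8, 4, device="cuda:0")
out = dp_model(x)  # forward() scatters x into two chunks of 4 rows, one per device
```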
The scatter call is ultimately implemented by the autograd Function `torch.nn.parallel._functions.Scatter`, like this:
```python
class Scatter(Function):
    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        ...

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
```
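Calling `Scatter` directly seems consistent with this: the forward pass splits the tensor and records a backward node in the graph (a sketch, assuming two CUDA devices; the argument order follows the `forward` signature above):

```python
import torch
from torch.nn.parallel._functions import Scatter

# Split an input across two devices by calling Scatter directly.
x = torch.randn(8, 4, device="cuda:0", requires_grad=True)
chunks = Scatter.apply([0, 1], None, 0, x)  # (target_gpus, chunk_sizes, dim, input)

print([c.device for c in chunks])  # one chunk on cuda:0, one on cuda:1
print(chunks[0].grad_fn)           # a backward node created by Scatter
```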
So I expect that during the backward pass, the `backward` function of `Scatter` should be called, and that this is how the gradients distributed across all of the devices get gathered back.
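That expectation matches how a plain custom `Function` behaves; for a toy `Function`, a print in `backward` does show up (a minimal sketch, unrelated to `DataParallel`):

```python
import torch
from torch.autograd import Function

# A toy Function to confirm that backward() of a custom Function
# is normally invoked during loss.backward().
class Probe(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        print("Probe.backward called")
        return grad_output * 2

y = Probe.apply(torch.ones(3, requires_grad=True))
y.sum().backward()  # prints "Probe.backward called"
```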
But when I try to do something in the backward function of `Scatter`, such as printing a line:
```python
class Scatter(Function):
    @staticmethod
    def backward(ctx, *grad_output):
        print("try to print something.")
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
```
I always get nothing printed. It seems that the `Scatter.backward()` function is never called.
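One more check I can think of is whether the gradients make it back to the original input at all (a self-contained sketch, same two-GPU assumption as above):

```python
import torch
import torch.nn as nn

# Check that gradients flow back through the scatter to the original input.
dp_model = nn.DataParallel(nn.Linear(4, 2).cuda(), device_ids=[0, 1])
x = torch.randn(8, 4, device="cuda:0", requires_grad=True)

dp_model(x).sum().backward()
print(x.grad.shape)  # expected torch.Size([8, 4]) if per-device gradients are gathered
```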
I wonder why this function is not called, and if so, how does `DataParallel` gather the gradients from all of the devices? Is there anything wrong with my test?
Thanks very much!