Backward function of "torch.nn.parallel._functions.Scatter" is never called?

The forward function of torch.nn.parallel.DataParallel calls its member function "scatter" to split the input data across all of the devices:

class DataParallel(Module):
    def forward(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        ...
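
As a side note, this splitting step is also exposed as torch.nn.parallel.scatter. A minimal sketch of what it does to a batch, assuming at least two visible GPUs:

import torch
from torch.nn.parallel import scatter

# Assumes at least two visible GPUs: scatter() splits the batch along
# dim 0 and places one chunk on each target device.
x = torch.randn(8, 4)
chunks = scatter(x, [0, 1])                   # -> two (4, 4) tensors
print([(c.shape, c.device) for c in chunks])  # cuda:0 and cuda:1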

And the scatter call is ultimately implemented by the autograd Function torch.nn.parallel._functions.Scatter, like this:

class Scatter(Function):
    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        ...

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
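
To see the forward half in isolation, the Function can be called directly. A rough sketch, assuming two visible GPUs (note that _functions is a private module, so this is only for poking around):

import torch
from torch.nn.parallel._functions import Scatter

# Assumes two visible GPUs. The arguments are (target_gpus, chunk_sizes,
# dim, input); chunk_sizes=None means equally sized chunks.
x = torch.randn(6, 3)
chunks = Scatter.apply([0, 1], None, 0, x)
print([(c.shape, c.device) for c in chunks])  # two (3, 3) chunks on cuda:0/cuda:1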

So I assumed that during the backward pass of DataParallel, the backward function of Scatter would be called, and that this is how the gradients distributed across the devices get gathered.
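
Gathering is indeed the inverse of scattering, which is why the backward above returns a Gather of the incoming grads. A small sketch of the underlying primitive, assuming two visible GPUs:

import torch
from torch.cuda import comm

# Assumes two visible GPUs: gather the per-device chunks back into one
# tensor on device 0, concatenating along dim 0.
parts = [torch.ones(3, 3, device="cuda:{}".format(i)) for i in (0, 1)]
whole = comm.gather(parts, dim=0, destination=0)
print(whole.shape)  # torch.Size([6, 3])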

But when I add something to the backward function of Scatter, such as printing a line:

class Scatter(Function):
    @staticmethod
    def backward(ctx, *grad_output):
        print("try to print something.")
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)

I never get anything printed. It seems that the Scatter.backward() function is never called.

I wonder why this function is not called, and then how DataParallel gathers the gradients from all of the devices. Is there anything wrong with my test?

Thanks very much!

I found it was my misunderstanding.
It is Broadcast.backward() rather than Scatter.backward() that gathers the grads from all of the devices: DataParallel replicates the model's parameters onto each device through Broadcast, and Broadcast.backward() reduce-adds the per-device parameter gradients back onto the source device. Scatter.backward() would only run if the scattered input itself required grad, which is usually not the case, so the print statement never fires.
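
For anyone who finds this later, here is a minimal sketch of that path, assuming two visible GPUs (again, _functions is a private module, so this is only for illustration):

import torch
from torch.nn.parallel._functions import Broadcast

# Assumes two visible GPUs: replicate() copies parameters to each device
# via Broadcast, and Broadcast.backward reduce-adds the per-device grads
# back onto the source device.
w = torch.randn(3, device="cuda:0", requires_grad=True)
copies = Broadcast.apply([0, 1], w)                   # one copy of w per GPU
loss = copies[0].sum() + copies[1].sum().to("cuda:0")
loss.backward()
print(w.grad)                                         # all 2s: one from each replica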