Autograd of Scatter

I recently wanted to modify `Scatter`. The original code is shown below; `Scatter` is a subclass of `Function`. However, when I run my code, I find that `Scatter.backward` is never executed, whether I use `DataParallel` or `DistributedDataParallel`.

class Scatter(Function):
    """Split ``input`` along ``dim`` and place one chunk on each target GPU.

    ``backward`` gathers the per-device gradients back via ``Gather`` onto
    ``ctx.input_device`` (``-1`` when the input was not a CUDA tensor).

    NOTE(review): autograd only records this node -- and therefore only ever
    calls ``backward`` -- when ``input`` requires grad. A plain data batch
    (``requires_grad=False``) never triggers ``Scatter.backward``.
    """

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        devices = [_get_device_index(gpu, True) for gpu in target_gpus]
        ctx.dim = dim
        if input.is_cuda:
            ctx.input_device = input.get_device()
        else:
            ctx.input_device = -1

        copy_streams = None
        if ctx.input_device == -1:
            # Host-to-device copies are issued on per-device background streams.
            copy_streams = [_get_stream(device) for device in devices]

        chunks = comm.scatter(input, devices, chunk_sizes, ctx.dim, copy_streams)

        if copy_streams is not None:
            # Make each device's current stream wait for its copy, then mark the
            # chunk as used on that stream so its memory is not reused early.
            for idx, chunk in enumerate(chunks):
                with torch.cuda.device(devices[idx]):
                    current = torch.cuda.current_stream()
                    current.wait_stream(copy_streams[idx])
                    chunk.record_stream(current)
        return chunks

    @staticmethod
    def backward(ctx, *grad_output):
        # One slot per forward argument (target_gpus, chunk_sizes, dim, input);
        # only ``input`` is differentiable.
        input_grad = Gather.apply(ctx.input_device, ctx.dim, *grad_output)
        return None, None, None, input_grad

Questions:

  1. When is `Scatter.backward` triggered?
  2. How does PyTorch compute gradients across multiple GPUs (with `DataParallel`) and across multiple nodes (with `DistributedDataParallel`)?

My code first runs the model’s forward pass, then uses the model’s output to calculate the loss on GPU 0, for example:

import torch.nn as nn
from torchvision.models import resnet50

def classify_loss_fn(output, label):
    """Return the cross-entropy loss of ``output`` logits against ``label``.

    ``label`` is cast to int64 class indices and moved onto ``output``'s
    device. A CPU label batch therefore works with a model output that
    DataParallel gathered onto a GPU.
    """
    # Only ``nn`` is imported by this snippet (``import torch.nn as nn`` does
    # not bind ``torch``), so reference the loss through ``nn``.
    target = label.long().to(output.device)
    return nn.CrossEntropyLoss()(output, target)

# GPUs to train on. DataParallel requires the module's parameters and buffers
# to be on DEVICE_IDS[0] before wrapping, and gathers outputs there by default.
DEVICE_IDS = [1, 2, 3, 4]
# Pick ONE wrapper: never wrap a DataParallel model in DistributedDataParallel,
# and DistributedDataParallel needs an initialized process group first.
USE_DDP = False

if USE_DDP:
    # DistributedDataParallel: one process per GPU (e.g. launched with
    # `torchrun --nproc_per_node=4 script.py`), each owning a single device.
    import os

    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    device_id = DEVICE_IDS[int(os.environ["LOCAL_RANK"])]
    model = nn.parallel.DistributedDataParallel(
        resnet50().to(f"cuda:{device_id}"),
        device_ids=[device_id],
        find_unused_parameters=True,
    )
else:
    # DataParallel: a single process that splits each batch across DEVICE_IDS.
    model = nn.DataParallel(
        resnet50().to(f"cuda:{DEVICE_IDS[0]}"), device_ids=DEVICE_IDS
    )

# NOTE: autograd only calls a Function's backward when one of its tensor inputs
# requires grad. `img` is plain data, so Scatter.backward never runs. With
# DataParallel, parameter gradients flow back through the replicate (broadcast)
# step and are summed onto DEVICE_IDS[0]. With DistributedDataParallel, they
# are averaged across processes by all-reduce during backward.
# `dataloader` is defined elsewhere (not shown in this snippet).
for img, label in dataloader:
    output = model(img)
    # Outputs live on the output device; bring the labels there as well.
    label = label.to(output.device)
    loss = classify_loss_fn(output, label)
    loss.backward()