I recently wanted to modify Scatter. The original code is below; Scatter is a subclass of Function. But when I run my code, I find that the backward function in Scatter never executes, whether I use DataParallel or DistributedDataParallel.
import torch
import torch.cuda.comm as comm
from torch.autograd import Function
from torch._utils import _get_device_index  # import path may differ by PyTorch version
from torch.nn.parallel._functions import Gather, _get_stream

class Scatter(Function):

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.is_cuda else -1
        streams = None
        if ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(device) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        # Grad w.r.t. the scattered input is the per-GPU grads gathered back;
        # the first three forward arguments (target_gpus, chunk_sizes, dim) get no grad.
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
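
For context, here is the kind of minimal direct test I am running against Scatter (a sketch of my own, assuming at least two visible GPUs; Scatter is imported from torch.nn.parallel._functions). As far as I understand, autograd only calls Scatter.backward when the backward pass reaches its outputs, i.e. when the scattered input requires grad and .backward() is called on a result:

    import torch
    from torch.nn.parallel._functions import Scatter

    # Sketch of my own test (assumes cuda:0 and cuda:1 exist).
    x = torch.randn(4, 3, requires_grad=True)    # CPU leaf tensor in the autograd graph
    outs = Scatter.apply([0, 1], None, 0, x)     # split dim 0 across cuda:0 and cuda:1
    loss = sum(o.sum().cpu() for o in outs)      # move partial sums to CPU before adding
    loss.backward()                              # this is what should trigger Scatter.backward
    print(x.grad.shape)                          # torch.Size([4, 3]) if backward ran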
Questions:
- When is the backward function in Scatter triggered?
- How does PyTorch compute gradients across multiple GPUs (with DataParallel) and across multiple nodes (with DistributedDataParallel)?
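
For reference, this is the kind of training step I am testing with (a minimal sketch of my own, not taken from the docs):

    import torch
    import torch.nn as nn

    model = nn.DataParallel(nn.Linear(10, 2).cuda())  # replicated across all visible GPUs
    inp = torch.randn(8, 10).cuda()                   # typical data batch (requires_grad=False)
    out = model(inp)                                  # forward pass scatters inp across the GPUs
    out.sum().backward()                              # this is where I expected Scatter.backward to run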