Is there an autograd function for the all-reduce operation?
Thanks!
I figured it out:
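from torch.autograd import Function

class AllReduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        # All-reduce the per-device input tensors.
        outputs = nccl_all_reduce(list(inputs))
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *gradOutputs):
        # The incoming gradients are all-reduced the same way as the inputs.
        gradInputs = nccl_all_reduce(list(gradOutputs))
        return tuple(gradInputs)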
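
For anyone who wants to check the gradient flow without multiple GPUs, here is a minimal CPU-only sketch. It assumes the all-reduce is a sum, and `fake_all_reduce` is a hypothetical stand-in for `nccl_all_reduce` (it is not a real library call); the tensors `a` and `b` play the role of the per-device inputs.

```python
import torch
from torch.autograd import Function


def fake_all_reduce(tensors):
    # Hypothetical CPU stand-in for nccl_all_reduce (assumed to be a sum):
    # every participant receives the elementwise sum of all inputs.
    total = torch.stack(tensors).sum(dim=0)
    return [total.clone() for _ in tensors]


class AllReduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        return tuple(fake_all_reduce(list(inputs)))

    @staticmethod
    def backward(ctx, *grad_outputs):
        # For a sum all-reduce, each input contributes to every output,
        # so its gradient is the sum of all output gradients.
        return tuple(fake_all_reduce(list(grad_outputs)))


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)

out_a, out_b = AllReduce.apply(a, b)

# Weight the two outputs differently so the backward all-reduce is visible.
loss = out_a.sum() + 2.0 * out_b.sum()
loss.backward()

print(out_a)   # both outputs hold a + b
print(a.grad)  # 1 (from out_a) + 2 (from out_b) per element
print(b.grad)
```

Since both outputs equal `a + b`, the loss is `3 * (a + b).sum()`, so each input should get a gradient of 3 per element, which is what the all-reduced backward produces.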