All Reduce Autograd Function

Is there an autograd function for the all-reduce operation?

Thanks!

I figured it out:

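from torch.autograd import Function


class AllReduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        # nccl_all_reduce: my helper (not shown) that sums the tensors across
        # GPUs and returns one result tensor per input device
        outputs = nccl_all_reduce(list(inputs))
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Every output is the sum of all inputs, so each input's gradient is
        # the sum of all output gradients: another all-reduce
        grad_inputs = nccl_all_reduce(list(grad_outputs))
        return tuple(grad_inputs)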
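To check the gradient rule without several GPUs, here is a minimal sketch that swaps the NCCL call for a hypothetical CPU stand-in (`cpu_all_reduce` and `CpuAllReduce` are illustrative names, not part of PyTorch). It assumes the all-reduce is a sum and that each participant receives its own copy of the result. The backward pass is the same all-reduce applied to the incoming gradients.

```python
import torch
from torch.autograd import Function


def cpu_all_reduce(tensors):
    # Hypothetical stand-in for nccl_all_reduce: sum all tensors and give
    # every participant its own copy of the result (no GPUs needed).
    total = torch.stack(tensors).sum(dim=0)
    return [total.clone() for _ in tensors]


class CpuAllReduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        return tuple(cpu_all_reduce(list(inputs)))

    @staticmethod
    def backward(ctx, *grad_outputs):
        # d(out_j)/d(in_i) = 1 for every i, j, so grad(in_i) = sum_j grad(out_j)
        return tuple(cpu_all_reduce(list(grad_outputs)))


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
out_a, out_b = CpuAllReduce.apply(a, b)
# Weight the outputs differently: each input should receive 1 + 2 = 3
loss = (1.0 * out_a).sum() + (2.0 * out_b).sum()
loss.backward()
print(out_a.tolist(), a.grad.tolist(), b.grad.tolist())
# prints [4.0, 6.0] [3.0, 3.0] [3.0, 3.0]
```

Both inputs get the same gradient because each one contributes equally to every output; `torch.autograd.gradcheck` on double-precision inputs confirms the analytical backward matches finite differences.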