Is it possible to write a custom autograd function that returns a list of tensors? For context, I’m trying to manually write the backward pass for all_gather.
This code:
class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, dim, reduce_dtype):
        ctx.reduce_dtype = reduce_dtype
        ctx.dim = dim
        output_list = list(torch.empty_like(tensor) for _ in range(_CONTEXT_PARALLEL_GROUP_SIZE))
        dist.all_gather(output_list, tensor, _CONTEXT_PARALLEL_GROUP)
        return output_list
        # return torch.cat(output_list, dim=dim)
fails because I can’t return a list of tensors.
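One possible direction, as a rough sketch rather than a confirmed fix: torch.autograd.Function.forward is documented to return either a single tensor or a tuple of tensors, so converting the list to a tuple should let autograd track each gathered tensor as its own output. The class name AllGatherTuple, the group argument, and the backward (sum the incoming gradients across ranks with all_reduce, then keep this rank's slot) are assumptions for illustration; dim and reduce_dtype from the code above are left out to keep it short.

```python
import torch
import torch.distributed as dist


class AllGatherTuple(torch.autograd.Function):
    """Sketch: all_gather whose forward returns a tuple instead of a list."""

    @staticmethod
    def forward(ctx, tensor, group=None):
        # `group` is a hypothetical argument; None means the default group.
        ctx.group = group
        world_size = dist.get_world_size(group)
        outputs = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(outputs, tensor.contiguous(), group=group)
        # Returning a tuple (not a list) lets autograd register every
        # element as a separate differentiable output.
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # One gradient arrives per forward output (one per rank).
        stacked = torch.stack(grad_outputs)
        # Assumed backward: sum each slot's gradient over all ranks, then
        # keep the slot that corresponds to this rank's original input.
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM, group=ctx.group)
        rank = dist.get_rank(ctx.group)
        # Second return value is the (non-existent) gradient for `group`.
        return stacked[rank], None
```

Calling AllGatherTuple.apply(x) then gives a tuple with one tensor per rank, which can be wrapped in list(...) at the call site if a list is really needed.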