How can I return a list of tensors in a PyTorch autograd function?

Is it possible to write a custom autograd function that returns a list of tensors? For context, I’m trying to manually write the backward pass for all_gather.

This code:

import torch
import torch.distributed as dist

# _CONTEXT_PARALLEL_GROUP and _CONTEXT_PARALLEL_GROUP_SIZE are module-level
# globals defined elsewhere (the process group and its size).
class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, dim, reduce_dtype):
        ctx.reduce_dtype = reduce_dtype
        ctx.dim = dim

        output_list = [torch.empty_like(tensor) for _ in range(_CONTEXT_PARALLEL_GROUP_SIZE)]
        dist.all_gather(output_list, tensor, _CONTEXT_PARALLEL_GROUP)
        return output_list
        # return torch.cat(output_list, dim=dim)

fails because I can’t return a list of tensors.
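A sketch of how the function above might look with a backward pass. It rests on a few assumptions of mine, not the asker's: a process group is already initialized; the group is passed in as a `group` argument instead of the module-level `_CONTEXT_PARALLEL_GROUP`, with its size coming from `dist.get_world_size`; `reduce_dtype` is taken to mean the dtype in which gradients are summed; and `dim` is dropped because nothing is concatenated. The key change is returning `tuple(outputs)`. Autograd unpacks a tuple returned from `forward` into separately tracked outputs, and `backward` then receives one gradient per gathered tensor.

```python
import torch
import torch.distributed as dist


class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group=None, reduce_dtype=None):
        # Assumption: `group` replaces the global _CONTEXT_PARALLEL_GROUP.
        ctx.group = group
        ctx.rank = dist.get_rank(group)  # this process's rank within `group`
        ctx.input_dtype = tensor.dtype
        # Assumption: reduce_dtype is the dtype used to sum gradients.
        ctx.reduce_dtype = reduce_dtype if reduce_dtype is not None else tensor.dtype

        world_size = dist.get_world_size(group)
        outputs = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(outputs, tensor.contiguous(), group=group)
        # Return a tuple, not a list, so each gathered tensor gets a grad_fn.
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # This rank's input shows up as outputs[rank] on every rank, so its
        # gradient is the sum over all ranks of their grad_outputs[rank].
        grads = torch.stack(grad_outputs).to(ctx.reduce_dtype)
        dist.all_reduce(grads, op=dist.ReduceOp.SUM, group=ctx.group)
        # Keep this rank's slice; `group` and `reduce_dtype` get no gradient.
        return grads[ctx.rank].to(ctx.input_dtype), None, None
```

The backward sums the stacked gradients with `all_reduce` and keeps this rank's entry. That is simple but communicates the full stack; `dist.reduce_scatter` would move less data, though not every backend implements it.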

Hi Vedant!

Yes. For example:

>>> import torch
>>> torch.__version__
'2.4.0'
>>> class Func (torch.autograd.Function):
...     @staticmethod
...     def forward (ctx, t):
...         return [t**2, t**3]
...
>>> Func.apply (torch.arange (3.))
[tensor([0., 1., 4.]), tensor([0., 1., 8.])]
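There is one caveat, as far as I can tell from how `Function.apply` handles outputs. Only a tuple returned from `forward` is unpacked into separately tracked outputs. A list is passed through as a single non-tensor value, so the tensors inside it get no `grad_fn` and `backward` is never called for them. The example above works only because its input doesn't require grad. Here is a minimal sketch of the same `Func` returning a tuple, with a `backward` that receives one gradient per output:

```python
import torch


class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t):
        ctx.save_for_backward(t)
        # A tuple (not a list) lets autograd track both outputs.
        return t**2, t**3

    @staticmethod
    def backward(ctx, grad_sq, grad_cube):
        # One incoming gradient per forward output.
        (t,) = ctx.saved_tensors
        # Chain rule: d(t**2)/dt = 2t and d(t**3)/dt = 3t**2.
        return grad_sq * 2 * t + grad_cube * 3 * t**2


t = torch.arange(3.0, requires_grad=True)
sq, cube = Func.apply(t)
(sq + cube).sum().backward()
print(t.grad)  # tensor([ 0.,  5., 16.])
```

If a list is really needed downstream, `list(Func.apply(t))` converts the tracked tuple after the fact.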

Best.

K. Frank