Custom autograd function breaking computation graph

I have the following autograd function that causes the tensors to lose their grad_fn:

class Combine(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensors, machine_mapping, dim):
        org_devices = []
        tensors_on_mm = []

        for tensor in tensors:
            org_devices.append(tensor.device)
            tensor = tensor.to(machine_mapping[0])
            tensors_on_mm.append(tensor)

        ctx.org_devices = org_devices
        ctx.dim = dim

        res = torch.cat(tensors_on_mm, dim)

        return res

    @staticmethod
    def backward(ctx, grad):
        chunks = torch.chunk(grad, len(ctx.org_devices), ctx.dim)

        grads = []
        for machine, chunk in zip(ctx.org_devices, chunks):
            chunk = chunk.to(machine)
            grads.append(chunk)

        return tuple(grads), None, None

Just for some context: this function is used in a distributed training setup where tensors that live on different GPUs can be combined.

My understanding is that this issue happens because of the tensor.to(machine_mapping[0]) line. However, whenever I implement the same functionality outside of the custom autograd.Function, it works fine. I am curious why this operation causes an issue and whether there is any way to work around it. I do need to stick with the custom function because, as mentioned earlier, this is a distributed training setup that requires tensors to be moved to and from devices in the forward and backward passes.
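
For illustration, this is roughly what I mean by doing it outside the custom function (the tensors and devices here are just placeholders, not my real setup); written as ordinary tensor ops, the result keeps its grad_fn:

import torch

def combine_plain(tensors, machine_mapping, dim):
    # same logic, but as ordinary tensor ops: .to() and torch.cat are
    # differentiable, so autograd records them automatically
    moved = [t.to(machine_mapping[0]) for t in tensors]
    return torch.cat(moved, dim)

t1 = torch.zeros(2, 3, requires_grad=True)
t2 = torch.ones(2, 3, requires_grad=True)
res = combine_plain([t1, t2], ['cpu'], 0)
print(res.grad_fn)   # <CatBackward0 ...> -- the graph is intact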

Hi,
A custom autograd.Function does not look for tensors that require grad inside list inputs.
You might be interested in something like pytreeify decorators · Issue #96337 · pytorch/pytorch · GitHub
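
A minimal sketch of this behavior (illustrative names, not from the original code): when the only tensor input is hidden inside a list, autograd never registers it, so the output of apply() is not connected to the graph.

import torch

class ListInput(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor_list):
        # the tensor is hidden inside a list, so autograd does not register it
        return tensor_list[0] * 2

    @staticmethod
    def backward(ctx, grad):
        # never reached in this demo: the graph was never connected
        return grad * 2

t = torch.ones(3, requires_grad=True)
out = ListInput.apply([t])
print(out.requires_grad)   # False -- autograd never saw a tensor input
print(out.grad_fn)         # None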


Hi kr!

As @soulitzer notes, autograd doesn't know to drill down into the tuple to discover that it contains tensors whose .grads should be tracked. You need to pass your custom autograd function explicit tensor arguments so that autograd will see them.

You can use Python's *args syntax to pass in a variable number of such arguments (as I presume is your use case).

Here is a script containing a tweaked version of your custom function:

import torch
print (torch.__version__)

class Combine(torch.autograd.Function):
    @staticmethod
    # def forward(ctx, tensors, machine_mapping, dim):
    def forward(ctx, machine_mapping, dim, *tensors):
        org_devices = []
        tensors_on_mm = []
        
        for tensor in tensors:
            org_devices.append(tensor.device)
            tensor = tensor.to(machine_mapping[0])
            tensors_on_mm.append(tensor)
        
        ctx.org_devices = org_devices
        ctx.dim = dim
        
        res = torch.cat(tensors_on_mm, dim)
        
        return res
    
    @staticmethod
    def backward(ctx, grad):
        chunks = torch.chunk(grad, len(ctx.org_devices), ctx.dim)
        
        grads = []
        for machine, chunk in zip(ctx.org_devices, chunks):
            chunk = chunk.to(machine)
            grads.append(chunk)
        
        # return tuple(grads), None, None
        return None, None, *grads

t1 = torch.zeros (2, 3, requires_grad = True, device = 'cpu')   # two different "machines"
t2 = torch.ones (2, 3, requires_grad = True, device = 'cuda')   # two different "machines"

machine_mapping = ['cpu']
dim = 0

res = Combine.apply (machine_mapping, dim, *(t1, t2))           # call with *args syntax

print ('t1 = ...')
print (t1)
print ('t2 = ...')
print (t2)
print ('res = ...')
print (res)
print ('res.device:', res.device)

res.sum().backward()

print ('t1.grad = ...')
print (t1.grad)
print ('t2.grad = ...')
print (t2.grad)

And here is its output:

2.6.0+cu126
t1 = ...
tensor([[0., 0., 0.],
        [0., 0., 0.]], requires_grad=True)
t2 = ...
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0', requires_grad=True)
res = ...
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<CombineBackward>)
res.device: cpu
t1.grad = ...
tensor([[1., 1., 1.],
        [1., 1., 1.]])
t2.grad = ...
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')

Best.

K. Frank
