DDP: second backward() accumulates the wrong gradient

Hello folks, I was trying to compute a Hessian-vector product by calling backward twice with DDP. However, the second backward doesn’t seem to accumulate gradients as desired. The same code works on a single GPU but produces wrong values in the multi-GPU case (2 GPUs in my case). Since this is also an autograd-related issue, I'd like to invite @albanD to take a look as well.

Results on a single GPU, which are correct :slight_smile: :

Running on rank 0
Hessian vector product:  tensor([1., 1., 1., 1.], device='cuda:0', grad_fn=<CatBackward>)
Done!

Results on two GPUs, which are wrong :frowning:

Running on rank 0
Running on rank 1
Hessian vector product:  tensor([0., 0., 0., 0.], device='cuda:0', grad_fn=<CatBackward>)
Hessian vector product:  tensor([0., 0., 0., 0.], device='cuda:1', grad_fn=<CatBackward>)
Done!

Here is the minimal repro code. I initialized the networks with constant values so that everything is deterministic.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def zero_grad(params):
    '''
    Reset the gradient of each parameter to zero, in place.

    A gradient produced by ``backward(create_graph=True)`` carries a
    ``grad_fn`` that links it into the autograd graph. Such a gradient is
    detached in place before zeroing, so the zeroed tensor no longer
    participates in autograd. Parameters whose ``.grad`` is None are skipped.

    :param params: iterable of parameters
    '''
    for p in params:
        if p.grad is None:
            continue
        if p.grad.grad_fn is not None:
            # Tensor.detach() returns a NEW tensor and leaves p.grad untouched;
            # detach_() is the in-place variant that actually severs the graph.
            p.grad.detach_()
        else:
            p.grad.requires_grad_(False)
        p.grad.zero_()


def collect_grad(params):
    '''
    Flatten and concatenate the gradients of ``params`` into a single vector.

    A parameter whose ``.grad`` is None contributes a block of zeros with the
    same number of elements. Side effect: every existing ``.grad`` is deleted
    (reset to None) right after it has been read.

    :param params: list of parameters
    :return: 1-D tensor holding all gradients back to back, in parameter order
    '''
    pieces = []
    for param in params:
        grad = param.grad
        if grad is None:
            # No gradient recorded: substitute zeros of the parameter's shape.
            pieces.append(torch.zeros_like(param).view(-1))
            continue
        pieces.append(grad.contiguous().view(-1))
        del param.grad
    return torch.cat(pieces)


def subprocess_fn(rank, args):
    '''
    Repro worker: compute a second-order gradient product with two backward passes.

    Builds D(G(z)) from two tiny constant-initialised Linear layers. It takes
    d{loss}/d{D} with ``create_graph=True``, then differentiates that vector,
    contracted with an all-ones vector, w.r.t. G's parameters, and prints the
    result. The post reports ones on a single GPU and zeros with DDP on 2 GPUs.

    :param rank: process index, also used as the CUDA device index via ``.to(rank)``
    :param args: parsed CLI namespace; ``num_gpus`` and ``distributed`` are read
    '''
    # NOTE(review): utils.helper is not shown; presumably setup/cleanup
    # initialise and tear down the distributed process group.
    setup(rank, args.num_gpus)
    print(f'Running on rank {rank}')
    D = nn.Linear(2, 1, bias=True).to(rank)
    G = nn.Linear(1, 2, bias=True).to(rank)
    # initialize weights (constants, so every rank computes identical values)
    nn.init.constant_(D.weight, 2.0)
    nn.init.constant_(D.bias, -1.0)
    nn.init.constant_(G.weight, 4.0)
    nn.init.constant_(G.bias, 1.0)

    # Only wrap in DDP when running with more than one GPU.
    if args.distributed:
        G = DDP(G, device_ids=[rank], broadcast_buffers=False)
        D = DDP(D, device_ids=[rank], broadcast_buffers=False)

    d_params = list(D.parameters())
    g_params = list(G.parameters())

    z = torch.ones((2, 1)).to(rank)
    loss = D(G(z)).mean()

    # First backward, keeping the graph so the resulting grads are themselves
    # differentiable. Per the post's diagnosis, under DDP this call also
    # triggers DDP's cross-device gradient synchronization. That appears to
    # break the graph and is the suspected cause of the all-zero result.
    # The workaround replaces this call with autograd.grad.
    loss.backward(create_graph=True)
    # collect_grad also deletes each D param's .grad after reading it.
    gradvec_d = collect_grad(d_params)   # d{loss} / d{D}

    # G's params also received .grad from the backward above; zero them so
    # the second backward starts from a clean slate.
    zero_grad(g_params)  # clean the grad before backward
    # Accumulate into G's .grad only (inputs=g_params). ``vec`` below is the
    # all-ones vector passed as grad_tensors.
    autograd.backward(gradvec_d,
                      grad_tensors=torch.ones_like(gradvec_d),
                      inputs=g_params)   # compute d{torch.dot(gradvec_d, vec)} / d{G}
    hvp = collect_grad(g_params)  # gather results

    print('Hessian vector product: ', hvp)
    cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True

    # Command line: how many GPUs (and therefore worker processes) to use.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--num_gpus', default=1, type=int, help='Number of GPUs')
    args = arg_parser.parse_args()
    # More than one GPU means DDP, with one spawned process per GPU.
    args.distributed = args.num_gpus > 1

    if not args.distributed:
        # Single process on device 0; nothing to spawn.
        subprocess_fn(0, args)
    else:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)

    print('Done!')

I just realized that the reason is probably that DDP performs a gradient synchronization across all devices every time backward() is called. That operation does not define a grad_fn and breaks the computation graph, so the gradient is not accumulated in the second backward().
Do you folks have any idea how to get around this? Any thoughts would be appreciated.

Update:

I eventually managed to get around this by using autograd.grad for the first backward pass, because autograd.grad doesn’t trigger DDP gradient synchronization. My code is attached below for anyone who runs into the same problem.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def _flat_grad_or_zeros(p):
    '''Return ``p.grad`` flattened to 1-D, or zeros shaped like ``p`` if it has no grad.'''
    grad = p.grad if p.grad is not None else torch.zeros_like(p)
    return grad.contiguous().view(-1)


def subprocess_fn(rank, args):
    '''
    Workaround worker: compute the same second-order product under DDP.

    The first derivative d{loss}/d{D} is taken with ``autograd.grad`` rather
    than ``loss.backward()``. Per the post, ``autograd.grad`` does not trigger
    DDP gradient synchronization, so the graph recorded by
    ``create_graph=True`` stays intact for the second differentiation.

    Consequence: ``grad_d`` is this rank's local gradient and is not averaged
    across ranks. That is harmless here because every rank builds identical
    constant-initialised layers and feeds the same all-ones input.

    The result is gathered with a helper defined alongside this function, so
    the script is self-contained.

    :param rank: process index, also used as the CUDA device index via ``.to(rank)``
    :param args: parsed CLI namespace; ``num_gpus`` and ``distributed`` are read
    '''
    # NOTE(review): utils.helper is not shown; presumably setup/cleanup
    # initialise and tear down the distributed process group.
    setup(rank, args.num_gpus)
    try:
        print(f'Running on rank {rank}')
        D = nn.Linear(2, 1, bias=True).to(rank)
        G = nn.Linear(1, 2, bias=True).to(rank)
        # initialize weights with constants so the result is deterministic
        nn.init.constant_(D.weight, 2.0)
        nn.init.constant_(D.bias, -1.0)
        nn.init.constant_(G.weight, 4.0)
        nn.init.constant_(G.bias, 1.0)

        if args.distributed:
            G = DDP(G, device_ids=[rank], broadcast_buffers=False)
            D = DDP(D, device_ids=[rank], broadcast_buffers=False)

        d_params = list(D.parameters())
        g_params = list(G.parameters())

        z = torch.ones((2, 1)).to(rank)
        loss = D(G(z)).mean()

        # d{loss}/d{D}, kept differentiable w.r.t. G via create_graph=True.
        grad_d = autograd.grad(loss, d_params, create_graph=True)
        gradvec_d = torch.cat([g.contiguous().view(-1) for g in grad_d])
        # Accumulate d{torch.dot(gradvec_d, vec)}/d{G} into G's .grad, where
        # vec is the all-ones vector passed as grad_tensors.
        autograd.backward(gradvec_d,
                          grad_tensors=torch.ones_like(gradvec_d),
                          inputs=g_params)
        hvp = torch.cat([_flat_grad_or_zeros(p) for p in g_params])

        print('Hessian vector product: ', hvp)
    finally:
        # Tear down even if the computation above raises.
        cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True

    # Command line: how many GPUs (and therefore worker processes) to use.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--num_gpus', default=1, type=int, help='Number of GPUs')
    args = arg_parser.parse_args()
    # More than one GPU means DDP, with one spawned process per GPU.
    args.distributed = args.num_gpus > 1

    if not args.distributed:
        # Single process on device 0; nothing to spawn.
        subprocess_fn(0, args)
    else:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)

    print('Done!')
2 Likes

Hey!

Interesting investigation.
Could you open an issue on GitHub about this, asking for native support (or a clear error if it isn't supported) and documentation of the workarounds?

Thanks!

1 Like

Sure! That would be great.