DDP: second backward() accumulates the wrong gradient

Update:

I eventually worked around this by using `autograd.grad` for the first backward pass. `autograd.grad` returns the gradients directly instead of accumulating them into `.grad`, so it does not trigger DDP's gradient synchronization. My code is below for anyone who runs into the same problem.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def _collect_grad(params):
    """Flatten the ``.grad`` of every parameter into a single 1-D tensor.

    A parameter whose ``.grad`` is ``None`` (not reached by the backward
    pass) contributes zeros, so the result always has
    ``sum(p.numel() for p in params)`` entries.
    """
    return torch.cat([
        (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
        for p in params
    ])


def subprocess_fn(rank, args):
    """Compute a mixed second-derivative product for a tiny G -> D model.

    With ``loss = mean(D(G(z)))``, this computes
    ``d/d(theta_G) [ 1^T * d(loss)/d(theta_D) ]``. That is the mixed Hessian
    block ``d^2 loss / (d theta_G d theta_D)`` applied to an all-ones vector.

    The first derivative (w.r.t. D's parameters) is taken with
    ``autograd.grad`` so that it is never accumulated into ``.grad``. The
    second derivative (w.r.t. G's parameters) is accumulated into ``.grad``
    via ``autograd.backward``. When G is wrapped in DDP, DDP synchronizes
    those accumulated gradients across ranks.

    Args:
        rank: Process rank. It is also used as the device index passed to
            ``.to(...)``.
        args: Parsed CLI namespace. It must provide ``num_gpus`` and
            ``distributed``.
    """
    setup(rank, args.num_gpus)
    try:
        print(f'Running on rank {rank}')
        D = nn.Linear(2, 1, bias=True).to(rank)
        G = nn.Linear(1, 2, bias=True).to(rank)
        # Deterministic initial weights so results are reproducible across ranks.
        nn.init.constant_(D.weight, 2.0)
        nn.init.constant_(D.bias, -1.0)
        nn.init.constant_(G.weight, 4.0)
        nn.init.constant_(G.bias, 1.0)

        if args.distributed:
            G = DDP(G, device_ids=[rank], broadcast_buffers=False)
            D = DDP(D, device_ids=[rank], broadcast_buffers=False)

        d_params = list(D.parameters())
        g_params = list(G.parameters())

        z = torch.ones((2, 1)).to(rank)
        loss = D(G(z)).mean()

        # autograd.grad returns the gradients without writing them into .grad.
        # This avoids the DDP gradient synchronization that a regular
        # backward() on D would trigger. create_graph=True keeps the graph so
        # the result can be differentiated again.
        grad_d = autograd.grad(loss, d_params, create_graph=True)
        gradvec_d = torch.cat([g.contiguous().view(-1) for g in grad_d])
        # Computes d{dot(gradvec_d, vec)} / d{G} with vec = ones. The result is
        # accumulated only into G's parameters because of inputs=g_params.
        autograd.backward(gradvec_d,
                          grad_tensors=torch.ones_like(gradvec_d),
                          inputs=g_params)
        hvp = _collect_grad(g_params)  # flatten G's accumulated gradients

        print('Hessian vector product: ', hvp)
    finally:
        # Always tear down, even if the computation above raised.
        cleanup()


if __name__ == '__main__':
    # Script entry point: parse the GPU count, then either spawn one worker
    # per GPU (distributed) or run a single worker in-process on rank 0.
    torch.backends.cudnn.benchmark = True

    cli_parser = ArgumentParser()
    cli_parser.add_argument('--num_gpus', type=int, help='Number of GPUs', default=1)
    args = cli_parser.parse_args()
    # More than one GPU means DDP wrapping and one spawned process per rank.
    args.distributed = args.num_gpus > 1

    if not args.distributed:
        subprocess_fn(0, args)
    else:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)

    print('Done!')
2 Likes