Autograd.backward() doesn't trigger reduction for Jacobian vector product

Hi everyone! I found that autograd.backward() doesn't trigger DDP's gradient reduction when I compute a certain Jacobian vector product (w.r.t. the D network) in the toy GAN case below. However, if I use the same approach to compute the Jacobian vector product w.r.t. the G network, DDP works perfectly with autograd.backward(). This is a follow-up to my previous post, where I found a way to compute Jacobian vector products with DDP (ddp-second-backward-accumulate-the-wrong-gradient).
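
For context, this is the double-backward pattern I am using to get the Jacobian vector product, written out as a minimal single-process sketch without DDP (module sizes and names here are just illustrative):

import torch
import torch.autograd as autograd
import torch.nn as nn

# toy networks mirroring the repro below (no DDP wrapping)
D = nn.Linear(2, 1, bias=True)
G = nn.Linear(1, 2, bias=True)
z = torch.randn(4, 1)

loss = D(G(z)).mean()

# first backward: per-parameter grads w.r.t. G, keeping the graph for a second backward
grad_g = autograd.grad(loss, list(G.parameters()), create_graph=True)
gradvec_g = torch.cat([g.contiguous().view(-1) for g in grad_g])

# second backward: differentiate <gradvec_g, ones> w.r.t. the D parameters
autograd.backward(gradvec_g,
                  grad_tensors=torch.ones_like(gradvec_g),
                  inputs=list(D.parameters()))

# collect the result, filling zeros for parameters that received no grad
jvp_d = torch.cat([p.grad.contiguous().view(-1) if p.grad is not None else torch.zeros_like(p).view(-1)
                   for p in D.parameters()])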

Compute the Jacobian vector product w.r.t. the D network: as shown below, the results I get from backward differ across devices.

Running on rank 0
Running on rank 1
Hessian vector product of d param: tensor([3., 3., 0.], device='cuda:1')
Hessian vector product of d param: tensor([2., 2., 0.], device='cuda:0')
Done!

The block below is the repro code.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def zero_grad(params):
    '''
    Clean the gradient of each parameter
    '''
    for p in params:
        if p.grad is not None:
            p.grad.detach_()  # detach in place
            p.grad.zero_()


def collect_grad(params):
    '''
    Collect grads of parameters and concatenate them into a vector.
    If grad is None, it will be filled with zeros
    :param params: list of parameters
    :return: vector
    '''
    grad_list = []
    for p in params:
        if p.grad is not None:
            grad_list.append(p.grad.contiguous().view(-1))
            del p.grad
        else:
            # replace None with zeros
            grad_list.append(torch.zeros_like(p).view(-1))
    return torch.cat(grad_list)


def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)
    print(f'Running on rank {rank}')
    D = nn.Linear(2, 1, bias=True).to(rank)
    G = nn.Linear(1, 2, bias=True).to(rank)
    # initialize weights
    nn.init.constant_(D.weight, 2.0)
    nn.init.constant_(D.bias, -1.0)
    nn.init.constant_(G.weight, 4.0)
    nn.init.constant_(G.bias, 1.0)

    if args.distributed:
        G = DDP(G, device_ids=[rank], broadcast_buffers=False)
        D = DDP(D, device_ids=[rank], broadcast_buffers=False)

    d_params = list(D.parameters())
    g_params = list(G.parameters())

    if not args.distributed:
        z = torch.tensor([[2.0], [1.0]]).to(rank)
    elif rank == 0:
        z = torch.tensor([[1.0]]).to(rank)
    elif rank == 1:
        z = torch.tensor([[2.0]]).to(rank)

    loss = D(G(z)).mean()

    grad_g = autograd.grad(loss, g_params, create_graph=True)  # first backward, keep the graph
    gradvec_g = torch.cat([g.contiguous().view(-1) for g in grad_g])

    zero_grad(d_params)  # clean the grad before backward
    autograd.backward(gradvec_g,
                      grad_tensors=torch.ones_like(gradvec_g),
                      inputs=d_params)   # compute d{torch.dot(gradvec_g, vec)} / d{D}
    hvp_d = collect_grad(d_params)  # gather results

    print(f'Hessian vector product of d param: {hvp_d}')
    cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = ArgumentParser()
    parser.add_argument('--num_gpus', type=int, help='Number of GPUs', default=1)
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1

    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)

    print('Done!')
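
Working the toy example out by hand (assuming I am reading the initialization correctly): loss = D(G(z)) = 16z + 3, the gradient w.r.t. the G parameters flattens to [2z, 2z, 2, 2], and differentiating its sum w.r.t. the D parameters gives [z + 1, z + 1, 0]. That is [2, 2, 0] on rank 0 (z = 1) and [3, 3, 0] on rank 1 (z = 2), which matches the unsynchronized per-rank values above; if DDP had averaged the grads, I would expect both ranks to print [2.5, 2.5, 0].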

Compute the Jacobian vector product w.r.t. the G network: as shown below, the Jacobian vector product is synchronized across devices.

Running on rank 0
Running on rank 1
Hessian vector product of g param: tensor([1.5000, 1.5000, 1.0000, 1.0000], device='cuda:1')
Hessian vector product of g param: tensor([1.5000, 1.5000, 1.0000, 1.0000], device='cuda:0')
Done!
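
Here the hand computation confirms that the reduction happened: the gradient w.r.t. the D parameters flattens to [4z + 1, 4z + 1, 1], and differentiating its sum w.r.t. the G parameters gives [z, z, 1, 1], i.e. [1, 1, 1, 1] on rank 0 and [2, 2, 1, 1] on rank 1. The printed [1.5, 1.5, 1, 1] is exactly their average.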

This repro code only switches the order of differentiation, as shown below.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def zero_grad(params):
    '''
    Clean the gradient of each parameter
    '''
    for p in params:
        if p.grad is not None:
            p.grad.detach_()  # detach in place
            p.grad.zero_()


def collect_grad(params):
    '''
    Collect grads of parameters and concatenate them into a vector.
    If grad is None, it will be filled with zeros
    :param params: list of parameters
    :return: vector
    '''
    grad_list = []
    for p in params:
        if p.grad is not None:
            grad_list.append(p.grad.contiguous().view(-1))
            del p.grad
        else:
            # replace None with zeros
            grad_list.append(torch.zeros_like(p).view(-1))
    return torch.cat(grad_list)


def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)
    print(f'Running on rank {rank}')
    D = nn.Linear(2, 1, bias=True).to(rank)
    G = nn.Linear(1, 2, bias=True).to(rank)
    # initialize weights
    nn.init.constant_(D.weight, 2.0)
    nn.init.constant_(D.bias, -1.0)
    nn.init.constant_(G.weight, 4.0)
    nn.init.constant_(G.bias, 1.0)

    if args.distributed:
        G = DDP(G, device_ids=[rank], broadcast_buffers=False)
        D = DDP(D, device_ids=[rank], broadcast_buffers=False)

    d_params = list(D.parameters())
    g_params = list(G.parameters())

    if not args.distributed:
        z = torch.tensor([[2.0], [1.0]]).to(rank)
    elif rank == 0:
        z = torch.tensor([[1.0]]).to(rank)
    elif rank == 1:
        z = torch.tensor([[2.0]]).to(rank)

    loss = D(G(z)).mean()

    grad_d = autograd.grad(loss, d_params, create_graph=True)
    gradvec_d = torch.cat([g.contiguous().view(-1) for g in grad_d])

    zero_grad(g_params)  # clean the grad before backward
    autograd.backward(gradvec_d,
                      grad_tensors=torch.ones_like(gradvec_d),
                      inputs=g_params)   # compute d{torch.dot(gradvec_d, vec)} / d{G}
    hvp_g = collect_grad(g_params)  # gather results

    print(f'Hessian vector product of g param: {hvp_g}')
    cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = ArgumentParser()
    parser.add_argument('--num_gpus', type=int, help='Number of GPUs', default=1)
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1

    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)

    print('Done!')

Update: there was a typo in my original first code block (the lines computing grad_g and gradvec_g were missing); the block above has been corrected accordingly, but my question remains the same.
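
For now I work around it by all-reducing the collected vector myself. A minimal sketch, assuming the default process group initialized by setup() is still active:

import torch.distributed as dist

# manually average hvp_d across ranks, mimicking the allreduce DDP should have done
dist.all_reduce(hvp_d, op=dist.ReduceOp.SUM)
hvp_d /= dist.get_world_size()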
