DDP: second backward() accumulates the wrong gradient

Hello folks, I was trying to compute a Hessian-vector product by calling backward twice with DDP. However, the second backward doesn't seem to accumulate the gradient as desired. The same code works on a single GPU but gives the wrong values in the multi-GPU case (2 GPUs in my case). Since this is also an autograd-related issue, maybe @albanD could take a look here as well.
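For context, the quantity I am after is the standard double-backward Hessian-vector product. A minimal single-process sketch (toy loss, no DDP, just to show the two autograd calls I rely on):

import torch

x = torch.randn(4, requires_grad=True)
loss = (x ** 3).sum()

# first derivative, keeping the graph so we can differentiate again
grad = torch.autograd.grad(loss, x, create_graph=True)[0]
v = torch.ones_like(x)
# Hessian-vector product: d(grad . v)/dx, which equals 6 * x * v for this toy loss
hvp = torch.autograd.grad(grad, x, grad_outputs=v)[0]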

Results on a single GPU, which are correct:

Running on rank 0
Hessian vector product:  tensor([1., 1., 1., 1.], device='cuda:0', grad_fn=<CatBackward>)
Done!

Results on two GPUs, which are wrong:

Running on rank 0
Running on rank 1
Hessian vector product:  tensor([0., 0., 0., 0.], device='cuda:0', grad_fn=<CatBackward>)
Hessian vector product:  tensor([0., 0., 0., 0.], device='cuda:1', grad_fn=<CatBackward>)
Done!

Here is the minimal repro code. I initialize the networks with constants so that everything is deterministic.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def zero_grad(params):
    '''
    Clean the gradient of each parameter
    '''
    for p in params:
        if p.grad is not None:
            p.grad.detach_()  # detach in place (a plain detach() here would be a no-op)
            p.grad.zero_()


def collect_grad(params):
    '''
    Collect grads of parameters and concatenate them into a vector.
    If grad is None, it will be filled with zeros
    :param params: list of parameters
    :return: vector
    '''
    grad_list = []
    for p in params:
        if p.grad is not None:
            grad_list.append(p.grad.contiguous().view(-1))
            del p.grad
        else:
            # replace None with zeros
            grad_list.append(torch.zeros_like(p).view(-1))
    return torch.cat(grad_list)


def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)
    print(f'Running on rank {rank}')
    D = nn.Linear(2, 1, bias=True).to(rank)
    G = nn.Linear(1, 2, bias=True).to(rank)
    # initialize weights
    nn.init.constant_(D.weight, 2.0)
    nn.init.constant_(D.bias, -1.0)
    nn.init.constant_(G.weight, 4.0)
    nn.init.constant_(G.bias, 1.0)

    if args.distributed:
        G = DDP(G, device_ids=[rank], broadcast_buffers=False)
        D = DDP(D, device_ids=[rank], broadcast_buffers=False)

    d_params = list(D.parameters())
    g_params = list(G.parameters())

    z = torch.ones((2, 1)).to(rank)
    loss = D(G(z)).mean()

    loss.backward(create_graph=True)
    gradvec_d = collect_grad(d_params)   # d{loss} / d{D}

    zero_grad(g_params)  # clean the grad before backward
    autograd.backward(gradvec_d,
                      grad_tensors=torch.ones_like(gradvec_d),
                      inputs=g_params)   # compute d{torch.dot(gradvec_d, vec)} / d{G}
    hvp = collect_grad(g_params)  # gather results

    print('Hessian vector product: ', hvp)
    cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = ArgumentParser()
    parser.add_argument('--num_gpus', type=int, help='Number of GPUs', default=1)
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1

    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)

    print('Done!')

I just realized that the reason is probably that DDP performs a gradient all-reduce across all devices every time backward() is called. That operation does not define a grad_fn, which breaks the computation graph, so the gradient is not accumulated correctly in the second backward().
Do you folks have any idea how to get around this? Any thoughts would be appreciated.
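A quick way to check this hypothesis on your own setup (this is my assumption about what the DDP reducer does, not something verified against its internals): with create_graph=True and no DDP, the accumulated .grad tensors carry a grad_fn; if DDP's copy-back drops it, the second backward has nothing to flow through.

# run right after loss.backward(create_graph=True) in the repro above
for name, p in D.named_parameters():
    has_fn = p.grad is not None and p.grad.grad_fn is not None
    print(f'rank {rank} | {name}.grad has grad_fn: {has_fn}')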

Update:

I eventually managed to get around this by using autograd.grad for the first backward, because autograd.grad won't trigger DDP gradient synchronization. I've attached my code below for anyone who runs into the same problem.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.autograd as autograd

from utils.helper import setup, cleanup
from argparse import ArgumentParser


def subprocess_fn(rank, args):
    setup(rank, args.num_gpus)
    print(f'Running on rank {rank}')
    D = nn.Linear(2, 1, bias=True).to(rank)
    G = nn.Linear(1, 2, bias=True).to(rank)
    # initialize weights
    nn.init.constant_(D.weight, 2.0)
    nn.init.constant_(D.bias, -1.0)
    nn.init.constant_(G.weight, 4.0)
    nn.init.constant_(G.bias, 1.0)

    if args.distributed:
        G = DDP(G, device_ids=[rank], broadcast_buffers=False)
        D = DDP(D, device_ids=[rank], broadcast_buffers=False)

    d_params = list(D.parameters())
    g_params = list(G.parameters())

    z = torch.ones((2, 1)).to(rank)
    loss = D(G(z)).mean()

    grad_d = autograd.grad(loss, d_params, create_graph=True)
    gradvec_d = torch.cat([g.contiguous().view(-1) for g in grad_d])
    autograd.backward(gradvec_d,
                      grad_tensors=torch.ones_like(gradvec_d),
                      inputs=g_params)   # compute d{torch.dot(gradvec_d, vec)} / d{G}
    hvp = collect_grad(g_params)  # gather results (collect_grad is the same helper defined in the first snippet)

    print('Hessian vector product: ', hvp)
    cleanup()


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = ArgumentParser()
    parser.add_argument('--num_gpus', type=int, help='Number of GPUs', default=1)
    args = parser.parse_args()
    args.distributed = args.num_gpus > 1

    if args.distributed:
        mp.spawn(subprocess_fn, args=(args, ), nprocs=args.num_gpus)
    else:
        subprocess_fn(0, args)

    print('Done!')
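One caveat I want to flag (my own reasoning, only checked on this toy example): since autograd.grad bypasses DDP's all-reduce, gradvec_d above is built from rank-local gradients only, so the resulting hvp may not be averaged across ranks the way DDP-averaged gradients would be. If your ranks see different data and you want the cross-rank average, you can reduce the final result yourself, e.g.:

import torch.distributed as dist

# average the per-rank HVP manually, since DDP never synchronized these grads
with torch.no_grad():
    dist.all_reduce(hvp, op=dist.ReduceOp.SUM)
    hvp /= args.num_gpus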

Hey!

Interesting investigation.
Could you open an issue on GitHub about this? It would be good to either support this natively (or raise a clear error if it isn't supported) and to document the workarounds.

Thanks!


Sure! That would be great.

@Hongkai_Zheng
Thank you for your method. However, how can autograd.grad be adopted in AMP training?

class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

When calling:

loss_scaler_A(loss_A, optimizer_A, clip_grad=max_norm,
              parameters=model_A.parameters(), create_graph=True,
              update_grad=(data_iter_step + 1) % accum_iter == 0)

loss_scaler_B(loss_B, optimizer_B, clip_grad=max_norm,
              parameters=model_B.parameters(), create_graph=False,
              update_grad=(data_iter_step + 1) % accum_iter == 0)

Note that the feature maps of B are used as the input of A, so I need the gradient flow to pass into B. To achieve this, I use create_graph=True. However, self._scaler.scale(loss).backward(create_graph=create_graph) here uses backward instead of autograd.grad. Could you help me with this?

Just for your reference, I asked GPT-4 about my issue. I gave it all the code you provided along with the essence of your explanation, and GPT-4 produced some quite reasonable-looking code (not strictly verified yet):

class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, parameters, clip_grad=None, create_graph=False, update_grad=True):
        # Convert parameters to a list to avoid exhausting the generator
        # and filter out parameters that don't require gradients
        param_list = [p for p in parameters if p.requires_grad]

        scaled_loss = self._scaler.scale(loss)

        if create_graph:
            # Calculate gradients while retaining the graph
            grads = torch.autograd.grad(scaled_loss, param_list, create_graph=True)
            # Assign the gradients to the parameters that require grad
            for param, grad in zip(param_list, grads):
                if param.grad is not None:
                    param.grad += grad
                else:
                    param.grad = grad
        else:
            # For the standard backward pass
            scaled_loss.backward()

        if update_grad and not create_graph:
            # Unscale and step only if not in create_graph mode
            self._scaler.unscale_(optimizer)
            if clip_grad is not None:
                assert parameters is not None 
                torch.nn.utils.clip_grad_norm_(param_list, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


# Initialize the scaler instances
loss_scaler_A = NativeScalerWithGradNormCount()
loss_scaler_B = NativeScalerWithGradNormCount()

# Example training loop iteration
for data_iter_step, data in enumerate(data_loader):
    optimizer_A.zero_grad()
    optimizer_B.zero_grad()

    # Compute losses for model A and B
    loss_A = compute_loss(model_A(data))
    loss_B = compute_loss(model_B(data))

    # Call the scaler for model A
    loss_scaler_A(loss_A, optimizer_A, clip_grad=max_norm,
                  parameters=model_A.parameters(), create_graph=True,
                  update_grad=(data_iter_step + 1) % accum_iter == 0)

    # Call the scaler for model B
    loss_scaler_B(loss_B, optimizer_B, clip_grad=max_norm,
                  parameters=model_B.parameters(), create_graph=False,
                  update_grad=(data_iter_step + 1) % accum_iter == 0)

    # ... rest of your training loop ...
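Another option I found for the create_graph branch is the pattern from the gradient-penalty example in the PyTorch AMP notes: take autograd.grad of the scaled loss and unscale the result manually, since scaler.unscale_() only touches the .grad buffers and does not help here. A sketch, not verified together with DDP:

import torch

def grads_with_graph(scaler, loss, params):
    # compute gradients of the scaled loss while keeping the graph alive
    scaled_grads = torch.autograd.grad(scaler.scale(loss), params, create_graph=True)
    # unscale manually; the multiplication is recorded in the graph
    inv_scale = 1.0 / scaler.get_scale()
    return [g * inv_scale for g in scaled_grads]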