Grad accumulation, freeing parts of the graph, and DataParallel

Hi, I have an odd use case that I need help with; however, my question mostly stems from my lack of understanding of how the backward pass works across multiple GPUs.

I want to do gradient accumulation over just one part of the model, instead of splitting the accumulation over the batch dimension of the entire model, and I want to do this in a multi-GPU setting. I want this because that part of the model has a large memory requirement but is fast to compute, while the other part of the model has a low memory requirement but is slow to compute. I can get this partial gradient accumulation to work on a single GPU, but it requires calling .backward() in a loop within my forward pass in order to free the graph buffers (and hence save memory).

Essentially, regular gradient accumulation with multiple GPUs works by placing the GPU split inside the accumulation loop; what I want instead is to place the accumulation loop inside the GPU split.

Now, I do not understand how the backward pass works across multiple GPUs well enough to know whether this would work. If I have already accumulated the gradients for the relevant model parameters on a single GPU via backward() calls, will those accumulated gradients be propagated back correctly across the GPUs? Will this work with DataParallel? If not, will it work with DistributedDataParallel?

I have attached code that shows how I'm doing partial gradient accumulation on a single GPU, specifically `accum_forward` in Model2. The code shows that the gradients are the same for all four ways of computing them.

import torch
from torch import nn
import numpy as np


"""
accum_forward does not work with PyTorch 0.4.1 but does work with PyTorch >= 1.7
"""


def set_random():
    """Seed every RNG this script draws from with the fixed value 42.

    Covers NumPy, the PyTorch CPU generator and, when CUDA is present,
    the generators of all visible CUDA devices.
    """
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m, 0.0, .001)


class Model1(torch.nn.Module):
    def __init__(self, v, d):
        self.d = d
        super(Model1, self).__init__()

        self.emb = torch.nn.Embedding(v, d)
        self.linear1 = torch.nn.Linear(self.d, self.d // 2)
        self.linear2 = torch.nn.Linear(self.d // 2, v)
        weight_init(self)

    def forward(self, inputs, targets):
        e = self.emb(inputs)
        logits = self.linear2(self.linear1(e))
        cost = torch.nn.functional.cross_entropy(logits, targets)
        return cost, logits


class Model2(torch.nn.Module):
    """Linear ``v -> v`` head applied on top of another model's logits.

    Provides three ways of computing the same loss/gradients:
    ``forward`` (plain), ``accum_forward`` (chunked, frees each chunk's graph
    right away) and ``retain_accum_forward`` (chunked, keeps the whole graph;
    used as a reference).

    Args:
        v: input and output width of the linear layer.
        d: accepted for signature parity with Model1; unused.
    """

    def __init__(self, v, d):
        # nn.Module must be initialized before any attribute is assigned.
        super(Model2, self).__init__()
        self.v = v
        self.linear = torch.nn.Linear(self.v, self.v)
        # weight_init handles one module at a time, so apply it to every submodule;
        # calling weight_init(self) directly would match neither branch and do nothing.
        self.apply(weight_init)

    def forward(self, base_logits, targets):
        """Return the mean cross-entropy of ``linear(base_logits)`` against ``targets``."""
        logits = self.linear(base_logits)
        cost = torch.nn.functional.cross_entropy(logits, targets)
        return cost

    def accum_forward(self, base_logits, targets, num=10):
        """Compute the loss in chunks, back-propagating (and freeing) each chunk at once.

        The head runs on a detached copy of ``base_logits``, so each chunk's
        ``backward()`` stops there and frees that chunk's graph immediately. The
        gradient accumulated on the copy is then pushed into the upstream graph
        with a single ``base_logits.backward(...)``.

        Side effects: accumulates into ``self.linear``'s ``.grad`` and, when
        ``base_logits`` requires grad, into the ``.grad`` of upstream leaves.

        Args:
            base_logits: input to ``self.linear``; split along dim 0.
            targets: class indices, split along dim 0 in step with ``base_logits``.
            num: requested number of chunks (``torch.chunk`` may return fewer).

        Returns:
            The full-batch mean loss as a detached scalar tensor. Its graph has
            already been consumed, so it is for reporting only.
        """
        # Sever the graph here so the per-chunk backward() calls stop at this tensor.
        base_logits_z = base_logits.detach().requires_grad_(True)
        batch_size = targets.shape[0]
        cost_total = base_logits.new_zeros(())
        for logits_chunk, targets_chunk in zip(torch.chunk(base_logits_z, num, 0),
                                               torch.chunk(targets, num, 0)):
            cost = torch.nn.functional.cross_entropy(self.linear(logits_chunk), targets_chunk)
            # torch.chunk may yield a smaller final chunk, so weight each chunk's
            # mean loss by its share of the batch to recover the full-batch mean.
            weighted = cost * (targets_chunk.shape[0] / batch_size)
            weighted.backward()
            # Detached: keeping the graph here would pin every consumed chunk graph.
            cost_total += weighted.detach()
        if base_logits.requires_grad and base_logits_z.grad is not None:
            # retain_graph=True: the caller may still back-propagate another loss
            # through the upstream graph shared with base_logits.
            base_logits.backward(gradient=base_logits_z.grad, retain_graph=True)
        return cost_total

    def retain_accum_forward(self, base_logits, targets, num=10):
        """Chunked reference version of ``forward`` that keeps the full graph alive.

        Each chunk back-propagates straight through ``base_logits`` with
        ``retain_graph=True``, so no memory is saved. It exists to check
        ``accum_forward`` against.

        Args:
            base_logits: input to ``self.linear``; split along dim 0.
            targets: class indices, split along dim 0 in step with ``base_logits``.
            num: requested number of chunks (``torch.chunk`` may return fewer).

        Returns:
            The full-batch mean loss. Unlike ``accum_forward``'s result, it still
            carries its (retained) autograd graph.
        """
        batch_size = targets.shape[0]
        cost_total = 0.
        for logits_chunk, targets_chunk in zip(torch.chunk(base_logits, num, 0),
                                               torch.chunk(targets, num, 0)):
            cost = torch.nn.functional.cross_entropy(self.linear(logits_chunk), targets_chunk)
            # Weight by chunk size so an uneven final chunk still gives the batch mean.
            weighted = cost * (targets_chunk.shape[0] / batch_size)
            weighted.backward(retain_graph=True)
            cost_total = cost_total + weighted
        return cost_total


def _param_grads(*models):
    """Return the ``.grad`` of every parameter of ``models``, in registration order."""
    return [p.grad for model in models for p in model.parameters()]


def _report(m1, m2, cost1, cost2):
    """Print gradient excerpts and both costs for one variant; return all its grads."""
    print(m1.emb.weight.grad[:1, :10])
    print(m1.linear1.weight.grad[:3, :10])
    print(m2.linear.weight.grad[:3, :10])
    print(cost1, cost2)
    return _param_grads(m1, m2)


def _test(device=None):
    """Check that four ways of back-propagating Model1 -> Model2 give the same grads.

    Variants: a plain full-batch backward, classic whole-model gradient
    accumulation, ``Model2.accum_forward`` and ``Model2.retain_accum_forward``.
    Each variant re-seeds before building its models, so all of them start from
    identical weights. At the end, the maximum absolute gradient difference of
    each variant relative to the plain version is printed.

    Args:
        device: torch device to run on; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    bs = 600
    v = 1000
    d = 256
    num_chunks = 10

    set_random()
    # Sample on the CPU generator, then move, so the data is the same on any device.
    inputs = torch.randint(0, v, [bs]).long().to(device)
    targets = torch.randint(0, v, [bs]).long().to(device)

    def fresh_models():
        # Re-seed so every variant starts from identical weights.
        set_random()
        return Model1(v, d=d).to(device), Model2(v, d=d).to(device)

    grads = {}

    print('\nRegular version')
    m1, m2 = fresh_models()
    cost1, logits = m1(inputs, targets)
    cost2 = m2(logits, targets)
    cost2.backward(retain_graph=True)
    cost1.backward()
    grads['Regular version'] = _report(m1, m2, cost1, cost2)

    print('\nRegular accum version')
    m1, m2 = fresh_models()
    total_cost1 = 0.
    total_cost2 = 0.
    for inputs_chunk, targets_chunk in zip(torch.chunk(inputs, num_chunks, 0),
                                           torch.chunk(targets, num_chunks, 0)):
        cost1, logits = m1(inputs_chunk, targets_chunk)
        cost2 = m2(logits, targets_chunk)
        # Weight by chunk size so an uneven final chunk still gives the batch mean.
        weight = targets_chunk.shape[0] / bs
        ((cost1 + cost2) * weight).backward()
        # Detach so the running totals don't keep every chunk's graph alive.
        total_cost1 += cost1.detach() * weight
        total_cost2 += cost2.detach() * weight
    grads['Regular accum version'] = _report(m1, m2, total_cost1, total_cost2)

    print('\nMy Accum')
    m1, m2 = fresh_models()
    cost1, logits = m1(inputs, targets)
    cost2 = m2.accum_forward(logits, targets, num=num_chunks)
    cost1.backward()
    grads['My Accum'] = _report(m1, m2, cost1, cost2)

    print('\nRetain Accum')
    m1, m2 = fresh_models()
    cost1, logits = m1(inputs, targets)
    cost2 = m2.retain_accum_forward(logits, targets, num=num_chunks)
    cost1.backward()
    grads['Retain Accum'] = _report(m1, m2, cost1, cost2)

    print('\nMax |grad - regular grad| per variant')
    reference = grads['Regular version']
    for name, variant in grads.items():
        diff = max((g - r).abs().max().item() for g, r in zip(variant, reference))
        print('{}: {:.3e}'.format(name, diff))



if __name__ == '__main__':
    # Script entry point: run the gradient-equivalence demo.
    _test()

We recommend using DistributedDataParallel over DataParallel. In DistributedDataParallel, you can use the `no_sync` context manager to disable gradient synchronization while you accumulate gradients. Once the desired number of accumulation steps has been reached, run the next backward pass outside the context, and gradient synchronization will happen in that pass.

1 Like

Thanks, that’s what I ended up doing. The code base was inherited with DataParallel, which is why I was hoping it could be used, but moving to DistributedDataParallel will be better in the long run.