Hi, I have an odd use case that I need help with. My question, however, mostly stems from my lack of understanding of how the backward pass works across multiple GPUs.

I want to be able to do gradient accumulation over just part of the model, instead of splitting the accumulation over the batch dimension of the entire model, and to do this in a multi-GPU setting. The reason is that this part of the model has a large memory requirement but is fast to compute, while the other part of the model has a low memory requirement but is slow to compute. I can get this partial gradient accumulation to work on a single GPU, but it requires calling `.backward()` in a loop within my forward pass in order to free the graph buffers (and hence save memory).

Essentially, regular gradient accumulation with multiple GPUs works by placing the GPU split inside the accumulation loop; what I want to do instead is place the accumulation loop inside the GPU split.

However, I do not understand how the backward pass works across multiple GPUs well enough to know whether this would work. If I have already accumulated the gradient for the necessary model parameters on a single GPU via `backward()` calls, will those accumulated gradients be scattered back correctly across the GPUs? Will this work with `DataParallel`? If not, will it work with `DistributedDataParallel`?

I have attached code that shows how I do partial gradient accumulation on a single GPU, specifically `accum_forward` in `Model2`. The code shows that the gradients are the same for all four ways they are computed.

```
import torch
from torch import nn
import numpy as np
"""
My accum does not work with pytorch 0.41 but does work in >=1.7
"""
def set_random():
    """Reseed NumPy, the PyTorch CPU RNG and, when CUDA is present, all CUDA RNGs.

    Every generator is seeded with 42 so repeated runs (and repeated model
    constructions inside one run) draw identical random numbers.
    """
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def weight_init(m):
    """Initialise the parameters of a single module in place.

    * ``nn.Linear``: Xavier-normal weight; bias (if any) set to zero.
    * ``nn.Embedding``: weight drawn from a normal distribution with mean 0
      and std 0.001.

    Any other module type is left untouched. This function does not recurse
    into submodules, so use ``model.apply(weight_init)`` to initialise a whole
    model.

    Args:
        m: the module whose parameters are initialised.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        # nn.init functions operate on tensors, so pass the weight tensor,
        # not the module itself.
        nn.init.normal_(m.weight, 0.0, .001)
class Model1(torch.nn.Module):
    """Embedding -> Linear(d, d // 2) -> Linear(d // 2, v) classifier.

    ``forward`` returns both the mean cross-entropy loss and the logits, so
    the logits can be fed to a downstream model (``Model2`` in this script).
    """

    def __init__(self, v, d):
        super(Model1, self).__init__()
        self.d = d
        hidden = self.d // 2
        # Submodules are created in the same order (emb, linear1, linear2) so
        # seeded runs draw identical default initialisations.
        self.emb = torch.nn.Embedding(v, self.d)
        self.linear1 = torch.nn.Linear(self.d, hidden)
        self.linear2 = torch.nn.Linear(hidden, v)
        # NOTE(review): weight_init only acts on nn.Linear / nn.Embedding
        # instances, so calling it on the whole model appears to be a no-op;
        # self.apply(weight_init) would visit the submodules -- confirm intent.
        weight_init(self)

    def forward(self, inputs, targets):
        """Return ``(mean cross-entropy loss, logits)`` for ``inputs``."""
        features = self.linear1(self.emb(inputs))
        logits = self.linear2(features)
        return torch.nn.functional.cross_entropy(logits, targets), logits
class Model2(torch.nn.Module):
    """Linear(v, v) head trained with cross-entropy on top of ``base_logits``.

    Three ways of computing the same loss are provided:

    * ``forward``: plain full-batch loss; the caller runs ``backward``.
    * ``accum_forward``: chunked loss on a detached copy of ``base_logits``.
      Each chunk is backpropagated immediately, which frees its graph
      buffers. The accumulated input gradient is then pushed into the
      upstream graph with a single ``base_logits.backward`` call.
    * ``retain_accum_forward``: chunked loss backpropagated chunk by chunk
      straight through the upstream graph (``retain_graph=True``).

    Both chunked variants weight every sample equally, so their gradients
    match ``forward`` even when the batch does not split into equal-size
    chunks.
    """

    def __init__(self, v, d):
        super(Model2, self).__init__()
        self.v = v
        # ``d`` is accepted but not used by this model.
        self.linear = torch.nn.Linear(self.v, self.v)
        # NOTE(review): weight_init only acts on nn.Linear / nn.Embedding
        # instances, so calling it on the whole model appears to be a no-op;
        # self.apply(weight_init) would visit the submodules -- confirm intent.
        weight_init(self)

    def forward(self, base_logits, targets):
        """Return the full-batch mean cross-entropy of ``linear(base_logits)``."""
        logits = self.linear(base_logits)
        cost = torch.nn.functional.cross_entropy(logits, targets)
        return cost

    def _chunk_cost(self, logits_chunk, target_chunk, batch_size):
        """Return one chunk's contribution to the full-batch mean loss."""
        logits = self.linear(logits_chunk)
        # Sum over the chunk, then divide by the FULL batch size. The chunk
        # contributions then add up to the full-batch mean even when the
        # chunks differ in size. (Assumes no target equals cross_entropy's
        # ignore_index, which the mean reduction would exclude from the count.)
        chunk_sum = torch.nn.functional.cross_entropy(
            logits, target_chunk, reduction='sum')
        return chunk_sum / batch_size

    def accum_forward(self, base_logits, targets, num=10):
        """Compute the loss in chunks, freeing each chunk's graph as it goes.

        ``base_logits`` is detached, so each chunk's ``backward`` only runs
        through ``self.linear`` and releases that chunk's saved buffers before
        the next chunk is processed. The gradient accumulated on the detached
        copy is then propagated into the upstream graph with one
        ``base_logits.backward`` call.

        Args:
            base_logits: output of an upstream model, split along dim 0.
            targets: class indices, split along dim 0 like ``base_logits``.
            num: number of chunks requested from ``torch.chunk`` (it may
                return fewer).

        Returns:
            The full-batch mean loss as a detached tensor. Gradients have
            already been accumulated into ``.grad`` of this module's
            parameters and, if ``base_logits`` requires grad, of the upstream
            parameters.
        """
        base_logits_z = base_logits.detach().requires_grad_(True)
        batch_size = targets.shape[0]
        cost_total = 0.
        for logits_chunk, target_chunk in zip(
                torch.chunk(base_logits_z, num, 0), torch.chunk(targets, num, 0)):
            cost = self._chunk_cost(logits_chunk, target_chunk, batch_size)
            cost.backward()
            # Detach so the running total does not reference the freed graph.
            cost_total += cost.detach()
        # A frozen / detached upstream has nothing to backpropagate into.
        if base_logits.requires_grad:
            # retain_graph=True keeps the upstream graph alive so the caller
            # can still backprop other losses (e.g. Model1's cost) through it.
            base_logits.backward(gradient=base_logits_z.grad, retain_graph=True)
        return cost_total

    def retain_accum_forward(self, base_logits, targets, num=10):
        """Compute the loss in chunks, backpropagating each through the full graph.

        Unlike ``accum_forward``, ``base_logits`` is not detached. Every
        chunk's ``backward`` therefore also traverses the upstream graph, and
        that graph must be retained throughout.

        Args:
            base_logits: output of an upstream model, split along dim 0.
            targets: class indices, split along dim 0 like ``base_logits``.
            num: number of chunks requested from ``torch.chunk``.

        Returns:
            The full-batch mean loss as a detached tensor. Gradients have
            already been accumulated by the per-chunk ``backward`` calls.
        """
        batch_size = targets.shape[0]
        cost_total = 0.
        for logits_chunk, target_chunk in zip(
                torch.chunk(base_logits, num, 0), torch.chunk(targets, num, 0)):
            cost = self._chunk_cost(logits_chunk, target_chunk, batch_size)
            # The upstream graph is shared by later chunks and by the caller.
            cost.backward(retain_graph=True)
            cost_total += cost.detach()
        return cost_total
def _build_models(v, device):
    """Reseed every RNG, then build a fresh Model1/Model2 pair on ``device``.

    Reseeding first makes every variant compared in ``_test`` start from
    identical parameters.
    """
    set_random()
    m1 = Model1(v, d=256).to(device)
    m2 = Model2(v, d=256).to(device)
    return m1, m2


def _report(m1, m2, cost1, cost2):
    """Print a slice of each gradient plus both costs; return full gradient copies."""
    grads = (m1.emb.weight.grad, m1.linear1.weight.grad, m2.linear.weight.grad)
    print(grads[0][:1, :10])
    print(grads[1][:3, :10])
    print(grads[2][:3, :10])
    print(cost1, cost2)
    return tuple(g.detach().clone() for g in grads)


def _test(device=None):
    """Compare four ways of computing gradients for the Model1 -> Model2 stack.

    The variants are:
    1. full batch;
    2. regular gradient accumulation over the whole model;
    3. ``Model2.accum_forward``;
    4. ``Model2.retain_accum_forward``.

    Each variant prints gradient slices and costs. At the end, the function
    prints whether each variant's full gradients match the full-batch
    version.

    Args:
        device: torch device to run on. Defaults to CUDA when available,
            otherwise CPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    bs = 600
    v = 1000
    num_chunks = 10
    set_random()
    inputs = torch.randint(0, v, [bs]).long().to(device)
    targets = torch.randint(0, v, [bs]).long().to(device)
    results = {}

    print('\nRegular version')
    m1, m2 = _build_models(v, device)
    cost1, logits = m1(inputs, targets)
    cost2 = m2(logits, targets)
    # Both losses share m1's graph, so keep it alive for cost1.backward().
    cost2.backward(retain_graph=True)
    cost1.backward()
    results['Regular version'] = _report(m1, m2, cost1, cost2)

    print('\nRegular accum version')
    m1, m2 = _build_models(v, device)
    inputs_chunks = torch.chunk(inputs, num_chunks, 0)
    targets_chunks = torch.chunk(targets, num_chunks, 0)
    n_chunks = len(inputs_chunks)
    total_cost1 = 0.
    total_cost2 = 0.
    for input_chunk, target_chunk in zip(inputs_chunks, targets_chunks):
        cost1, logits = m1(input_chunk, target_chunk)
        cost2 = m2(logits, target_chunk)
        # Dividing by the chunk count is exact here only because bs is
        # divisible by num_chunks, so all chunks have equal size.
        (cost1 / n_chunks + cost2 / n_chunks).backward()
        # Detach so the running totals do not keep each chunk's graph alive.
        total_cost1 += cost1.detach()
        total_cost2 += cost2.detach()
    total_cost1 /= n_chunks
    total_cost2 /= n_chunks
    results['Regular accum version'] = _report(m1, m2, total_cost1, total_cost2)

    print('\nMy Accum')
    m1, m2 = _build_models(v, device)
    cost1, logits = m1(inputs, targets)
    cost2 = m2.accum_forward(logits, targets, num=num_chunks)
    cost1.backward()
    results['My Accum'] = _report(m1, m2, cost1, cost2)

    print('\nRetain Accum')
    m1, m2 = _build_models(v, device)
    cost1, logits = m1(inputs, targets)
    cost2 = m2.retain_accum_forward(logits, targets, num=num_chunks)
    cost1.backward()
    results['Retain Accum'] = _report(m1, m2, cost1, cost2)

    # Check numerically, rather than by eye, that all variants agree.
    reference = results['Regular version']
    for name, grads in results.items():
        same = all(torch.allclose(g, r, rtol=1e-4, atol=1e-6)
                   for g, r in zip(grads, reference))
        print(f'{name}: gradients match regular version: {same}')
if __name__ == '__main__':
    # Run the four-way gradient comparison only when executed as a script.
    _test()
```