Manually edit tensor grad before DDP gradient sync

I wonder if it is valid to manually edit a parameter’s .grad in a DDP model before the gradients are synced. This is what I am trying to do:

1. base_model = MyModel().cuda(device_ids[0])
2. ddp_model = DDP(base_model, device_ids)

3. outputs = ddp_model(input)
4. loss1, loss2 = loss_fn(outputs)

5. with ddp_model.no_sync():
6.     local_loss = loss1 + 0. * loss2
7.     local_loss.backward(retain_graph=True)
8.     for p in base_model.sub_model2.parameters():
9.         p.grad *= 0.
10.     for p in base_model.sub_model3.parameters():
11.         p.grad *= 0.

12. local_loss = 0. * loss1 + loss2
13. local_loss.backward()

14. optimizer.step()

As shown in the code snippet above, I manually modify the base_model’s gradient at lines 8 - 11 before syncing the gradient at line 13. My goal is to use loss1 to only update sub_model1 in the base_model, and use loss2 to update the whole base_model.

The code runs without error, but I am concerned that this manual modification of the gradients may cause issues with the gradient sync mechanism in DDP.

Hey @albert.cwkuo

With the above code, I think DDP will still sync all grads for both loss1 and loss2, because the flag controlled by the no_sync context manager is checked when DistributedDataParallel.forward() is called. Since the forward pass here runs outside the no_sync context, DDP will still prepare to sync all grads during both backward passes.
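For reference, a minimal sketch of how no_sync is usually applied, reusing the names from the snippet above (note it needs a separate forward pass for each backward, since DDP decides at forward time whether the following backward will be all-reduced):

with ddp_model.no_sync():
    outputs = ddp_model(input)            # forward *inside* no_sync ...
    loss1, loss2 = loss_fn(outputs)
    (loss1 + 0. * loss2).backward()       # ... so this backward only accumulates local grads

outputs = ddp_model(input)                # forward *outside* no_sync
loss1, loss2 = loss_fn(outputs)
(0. * loss1 + loss2).backward()           # this backward triggers the gradient all-reduce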

How is MyModel implemented? Does it contain two independent submodules sub_model1 and sub_model2 and do something like the following?

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_model1 = SomeModel1()
        self.sub_model2 = SomeModel2()

    def forward(self, input):
        return self.sub_model1(input), self.sub_model2(input)

Thanks @mrshenli for your reply. This is what MyModel does internally.

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_model1 = SomeModel1()
        self.sub_model2 = SomeModel2()

    def forward(self, input):
        out1 = self.sub_model2(input)
        out2 = self.sub_model1(out1)
        return out1, out2

I want the gradient to be synced eventually when I call backward() at line 13, so the above code seems correct? My goal is to accumulate gradient from both loss1 and loss2 to sub_model2 and accumulate gradient only from loss2 to sub_model1. That’s why I try to zero out the grad in sub_model1.

Note that I use local_loss = loss1 + 0.0 * loss2 and local_loss = 0.0 * loss1 + loss2 to mask out part of the loss before calling local_loss.backward().
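As a sanity check on the 0.0-masking trick, here is a standalone toy example (not the actual model): multiplying a loss by 0.0 still runs backward through that branch, it just contributes zero gradients, so the corresponding .grad tensors exist but are zero rather than None.

import torch

w1 = torch.ones(1, requires_grad=True)   # stands in for a parameter touched by loss1
w2 = torch.ones(1, requires_grad=True)   # stands in for a parameter touched by loss2
loss1, loss2 = (2 * w1).sum(), (3 * w2).sum()

(loss1 + 0. * loss2).backward()
print(w1.grad)   # tensor([2.])
print(w2.grad)   # tensor([0.]) -- zero, not None, because the masked branch is still traversed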

----------------update------------

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_model1 = SomeModel1()
        self.sub_model2 = SomeModel2()
        self.sub_model3 = SomeModel3()

    def forward(self, input):
        out1 = self.sub_model2(self.sub_model1(input))
        out2 = self.sub_model3(out1)
        return out1, out2

My goal is to accumulate gradient from both loss1 and loss2 to sub_model1 and accumulate gradient only from loss2 to sub_model2 and sub_model3. That’s why I try to zero out grad in sub_model2 and sub_model3.

In that case, calling backward once on loss1 + loss2 might be sufficient. Is the following result what you want?

import torch
import torch.nn as nn

class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        with torch.no_grad():
            self.net1 = nn.Linear(1, 1)
            self.net1.weight.copy_(torch.ones(1, 1))
            self.net1.bias.copy_(torch.zeros(1))
            self.net2 = nn.Linear(1, 1)
            self.net2.weight.copy_(torch.ones(1, 1))
            self.net2.bias.copy_(torch.zeros(1))

    def forward(self, x):
        out1 = self.net1(x)
        out2 = self.net2(out1)
        return out1 + out2


print("==============")
model = MyModel()
model(torch.ones(1, 1)).sum().backward()
print("net1 grad is: ", model.net1.weight.grad)
print("net2 grad is: ", model.net2.weight.grad)

print("==============")
model = MyModel()
model.net1(torch.ones(1, 1)).sum().backward()
print("net1 grad is: ", model.net1.weight.grad)
print("net1 grad is: ", model.net2.weight.grad)

print("==============")
model = MyModel()
model.net2(torch.ones(1, 1)).sum().backward()
print("net1 grad is: ", model.net1.weight.grad)
print("net2 grad is: ", model.net2.weight.grad)

The outputs are:

==============
net1 grad is:  tensor([[2.]])
net2 grad is:  tensor([[1.]])
==============
net1 grad is:  tensor([[1.]])
net2 grad is:  None
==============
net1 grad is:  None
net2 grad is:  tensor([[1.]])
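To spell out the first case: the output is net1(x) + net2(net1(x)), so with the all-ones weights and zero biases above, the derivative w.r.t. net1.weight is x + net2.weight * x = 1 + 1 = 2, while the derivative w.r.t. net2.weight is just out1 = 1, matching the printed grads.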

Okay, that makes sense, but let me clarify how MyModel works internally, since the proposed solution may not work in that case. I have updated my reply above.

In short, I have 3 subnets inside MyModel. One of the losses depends on sub_model1 and sub_model2, but I only want to update sub_model1 with this loss. Therefore I need to zero out the grad in sub_model2 when calling backward on that loss.

With the above statement, should the forward function be something like below?

    def forward(self, input):
        out1 = self.sub_model1(input)
        out2 = self.sub_model3(self.sub_model2(out1))
        return out1, out2

With this code, out1.sum().backward() will only compute grads for sub_model1, and out2.sum().backward() will compute grads for all sub-models. And (out1 + out2).sum().backward() should satisfy the statement quoted above.
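To make that concrete, here is a quick single-process check of the proposed forward, with toy nn.Linear layers standing in for the real submodules:

import torch
import torch.nn as nn

class ProposedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_model1 = nn.Linear(2, 2)   # stand-ins for the real submodules
        self.sub_model2 = nn.Linear(2, 2)
        self.sub_model3 = nn.Linear(2, 2)

    def forward(self, input):
        out1 = self.sub_model1(input)
        out2 = self.sub_model3(self.sub_model2(out1))
        return out1, out2

model = ProposedModel()
out1, out2 = model(torch.randn(1, 2))

out1.sum().backward(retain_graph=True)
# only sub_model1 has grads at this point
print(all(p.grad is None for p in model.sub_model2.parameters()))   # True
print(all(p.grad is None for p in model.sub_model3.parameters()))   # True

out2.sum().backward()
# now every submodule has grads
print(all(p.grad is not None for p in model.parameters()))          # True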

In my use case it’s

    def forward(self, input):
        out1 = self.sub_model2(self.sub_model1(input))
        out2 = self.sub_model3(out1)
        return out1, out2

That’s the tricky part. That’s why I want to zero out the gradient of sub_model2 after calling backward on loss1 (loss1 is computed on out1).
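For completeness, a single-process sketch of that masking idea (no DDP here; toy nn.Linear layers and .sum() placeholders stand in for the real submodules and losses), where loss1 ends up updating only sub_model1 while loss2 reaches everything:

import torch
import torch.nn as nn

sub_model1 = nn.Linear(4, 4)   # stand-ins for the real submodules
sub_model2 = nn.Linear(4, 4)
sub_model3 = nn.Linear(4, 4)

x = torch.randn(2, 4)
out1 = sub_model2(sub_model1(x))
out2 = sub_model3(out1)
loss1, loss2 = out1.sum(), out2.sum()   # placeholder losses

loss1.backward(retain_graph=True)       # grads land on sub_model1 and sub_model2

# mask out loss1's contribution to sub_model2 before the second backward
for p in sub_model2.parameters():
    if p.grad is not None:
        p.grad.zero_()

loss2.backward()                        # accumulates loss2's grads on top

# result: sub_model1 sees loss1 + loss2; sub_model2 and sub_model3 see loss2 only

Whether this interacts cleanly with DDP’s gradient all-reduce still depends on the no_sync / forward placement discussed earlier in the thread.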