Gradients from calling backward on the sum of losses differ from calling backward on each loss one by one

Hi, suppose we have two losses. I tried two approaches: summing the losses first and calling backward once, and calling backward on each loss one by one. I expected the gradients w.r.t. a layer to be equal in both cases, but I found that they differ.
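In exact arithmetic the expectation is correct, because differentiation is linear. Writing $W$ for a layer's weights and $L_0, L_1$ for the two per-sample losses (notation introduced here for illustration):

$$
\nabla_W (L_0 + L_1) = \nabla_W L_0 + \nabla_W L_1
$$

Autograd evaluates the left side in one backward pass and the right side as two backward passes whose results are added into `.grad`, so in floating point the two sides can only be expected to agree up to rounding error.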


```python
import torch
import torch.nn as nn
from torch.backends import cudnn

cudnn.deterministic = True
cudnn.benchmark = False

torch.manual_seed(2)
torch.set_printoptions(precision=16)

a = torch.randn(2, 3, 10, 10).cuda()
label = torch.tensor([0, 1]).cuda()

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.l1 = nn.Conv2d(3, 10, 3, 1, 1)
        self.l2 = nn.Conv2d(10, 2, 3, 1, 1)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)

        x = x.view(x.size(0), -1)
        return x


model = Model().cuda()


model.zero_grad()

x = torch.randn_like(a)
y = model(x)

loss = y.sum(dim=1)  # one loss per sample
print(loss)

# 1) Backward through the sum of both losses.
loss.sum().backward(retain_graph=True)
grad_l1_1 = model.l1.weight.grad.clone()
grad_l2_1 = model.l2.weight.grad.clone()
model.zero_grad()

# 2) Backward through each loss separately; .grad accumulates both results.
loss[0].backward(retain_graph=True)
loss[1].backward(retain_graph=True)
grad_l1_2 = model.l1.weight.grad.clone()
grad_l2_2 = model.l2.weight.grad.clone()

# Sum of absolute differences between the two gradient versions.
diff_l1 = (grad_l1_2 - grad_l1_1).abs().sum()
diff_l2 = (grad_l2_2 - grad_l2_1).abs().sum()

print(diff_l1)
print(diff_l2)
```

```
tensor(2.1591782569885254e-05, device='cuda:0')
tensor(0., device='cuda:0')
```

I think the reason might be a floating-point precision issue, but this is just a two-layer network. Is it normal to get gradients this different?

Thanks!
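One way to test the floating-point explanation is to rerun the two-layer example on the CPU in float64. This sketch makes assumptions that are not in the original post: the model is rebuilt with `nn.Sequential`, no CUDA is used, and everything is cast to double precision. If the discrepancy comes only from rounding, the relative difference between the two gradients should be at most on the order of float64 machine epsilon (about 1e-16).

```python
import torch
import torch.nn as nn

torch.manual_seed(2)

# Same layer shapes as the two-layer model above, but on CPU and in float64
# (an assumption made here to isolate rounding effects).
model = nn.Sequential(
    nn.Conv2d(3, 10, 3, 1, 1),
    nn.Conv2d(10, 2, 3, 1, 1),
    nn.Flatten(),
).double()

x = torch.randn(2, 3, 10, 10, dtype=torch.float64)
loss = model(x).sum(dim=1)  # one loss per sample

# 1) Backward through the summed loss.
model.zero_grad()
loss.sum().backward(retain_graph=True)
grad_sum = model[0].weight.grad.clone()

# 2) Backward through each loss separately; .grad accumulates the two results.
model.zero_grad()
loss[0].backward(retain_graph=True)
loss[1].backward()
grad_each = model[0].weight.grad.clone()

# Relative (not absolute) difference between the two first-layer gradients.
rel_err = ((grad_each - grad_sum).abs().sum() / grad_sum.abs().sum()).item()
print(rel_err)
```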

And if we do the same with a deeper network such as ResNet-50:


```python
import torch
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, Bottleneck

a = torch.randn(2, 3, 10, 10).cuda()
label = torch.tensor([0, 1]).cuda()


class Model(ResNet):
    def __init__(self) -> None:
        # [3, 4, 6, 3] Bottleneck blocks is the ResNet-50 configuration.
        super().__init__(Bottleneck, [3, 4, 6, 3])

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        f = torch.flatten(x, 1)
        x = self.fc(f)

        return x, f


model = Model().cuda()

model.zero_grad()

x = torch.randn_like(a)
y, f = model(x)

# One cross-entropy loss per sample.
loss = F.cross_entropy(y, label, reduction="none")

print(loss)

# 1) Backward through the summed loss.
loss.sum().backward(retain_graph=True)
grad_l1_1 = model.conv1.weight.grad.clone()
model.zero_grad()

# 2) Backward through each loss separately; .grad accumulates both results.
loss[0].backward(retain_graph=True)
loss[1].backward(retain_graph=True)
grad_l1_2 = model.conv1.weight.grad.clone()

diff_l1 = (grad_l1_2 - grad_l1_1).abs().sum()
print(diff_l1)
```

we get a large difference in the gradients of the first conv layer:


```
tensor(5.1960111815801071e+20, device='cuda:0')
```

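The difference in the first example is expected and is caused by the limited floating-point precision: floating-point addition is not associative, so accumulating the per-loss gradients in a different order than a single backward pass can change the low-order bits of the result.
In your second model you are also comparing absolute errors instead of relative ones. The gradients already have a huge magnitude: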
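A minimal illustration of why summation order matters in float32 (the values below are chosen for the demonstration and are not from the thread): the same three numbers give different sums depending on how they are grouped. Roughly, a single backward through the summed loss reduces each weight gradient in one pass, while two backward calls produce two partial results that autograd then adds into `.grad`, so the additions are grouped differently.

```python
import torch

# float32 addition is not associative: the result depends on grouping/order.
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

left = (a + b) + c   # 1e8 - 1e8 cancels exactly, then + 1
right = a + (b + c)  # -1e8 + 1 rounds back to -1e8 (float32 spacing near 1e8 is 8)

print(left.item(), right.item())  # 1.0 0.0
```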

```python
grad_l1_1.abs().max()
# tensor(5.4373049958400000e+11, device='cuda:0')
```

these absolute errors might also be expected, as the relative error shows:

```python
diff_l1 = (grad_l1_2 - grad_l1_1).abs().sum() / grad_l1_1.abs().sum()
print(diff_l1)
# tensor(0.0017766774399206, device='cuda:0')
```
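The absolute-versus-relative point can be sketched with synthetic data (the tensors and the helper name `abs_and_rel_error` are hypothetical, not from the thread): when every element has a magnitude around 5e11, a perturbation of roughly one part per million yields a large absolute error but a tiny relative one. `torch.allclose` with an `rtol` performs the corresponding elementwise relative check.

```python
import torch

def abs_and_rel_error(approx, ref):
    """Hypothetical helper: (sum of |approx - ref|, that sum / sum of |ref|)."""
    abs_err = (approx - ref).abs().sum()
    rel_err = abs_err / ref.abs().sum()
    return abs_err.item(), rel_err.item()

# Synthetic "gradients" with magnitude ~5e11 (similar to the conv1 gradient
# above) and a copy perturbed by about one part per million.
ref = torch.full((1000,), 5e11, dtype=torch.float32)
approx = ref * (1 + 1e-6)

abs_err, rel_err = abs_and_rel_error(approx, ref)
print(f"absolute: {abs_err:.3e}, relative: {rel_err:.3e}")

# torch.allclose checks |approx - ref| <= atol + rtol * |ref| elementwise.
print(torch.allclose(approx, ref, rtol=1e-5))  # True
```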

Thanks for your reply!