Different gradients with/without using torch.nn.utils.checkpoint in the model

I wrote a small test to check whether I am using torch.utils.checkpoint correctly. Strangely, when I call checkpoint to run the first Conv2d layer, the weight updates differ from those of an otherwise identical model that does not call checkpoint at all. I am not sure what goes wrong. Thanks

Here is my experiment code, running on PyTorch 1.5.1:

import torch
import copy
from torch import nn
from torch.utils.checkpoint import checkpoint

num_classes = 3

class submodule(nn.Module):
    def __init__(self):
        super(submodule, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        self.conv2 = submodule()
        self.fc1 = nn.Linear(1*3*3, num_classes)

    def run_func(self, cell):
        # Wrap a layer so checkpoint can pass it extra arguments; only input[0] is used
        def call_cell(*input):
            return cell(input[0])
        return call_cell

    def forward(self, x):
        # Dummy tensor that requires grad, so the checkpointed output requires grad
        # even though x does not (torch.rand also advances the global RNG)
        dummy_input = torch.rand([1], requires_grad=True)
        x = checkpoint(self.run_func(self.conv1), x, dummy_input)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return x

class Net2(Net):
    # Same layers as Net, but conv1 is called directly, without checkpoint
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return x

def checking(model):
    # Assumed: reseed so both runs start from the same RNG state
    torch.manual_seed(0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    criteria = torch.nn.MSELoss()
    parameter_result = []
    for i in range(499):
        ref = model(torch.rand([3, 1, 3, 3]))
        # Random one-hot targets, shape [3, num_classes]
        hyp = torch.eye(num_classes)[torch.randint(low=0, high=num_classes, size=[3])]
        loss = criteria(ref, hyp)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            # Snapshot the weights; deepcopy so later updates do not overwrite it
            parameter_result.append(copy.deepcopy(model.state_dict()))
    return parameter_result

model1 = Net()
model2 = Net2()
# Start both models from identical weights
model2.load_state_dict(model1.state_dict())

parameter_result1 = checking(model1)
parameter_result2 = checking(model2)

for index, (a, b) in enumerate(zip(parameter_result1, parameter_result2)):
    for w in b.keys():
        assert torch.allclose(a[w], b[w]), "{}\n{}\n{}\n{}".format(w, a[w], b[w], a[w]-b[w])
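A side note on the code above, offered as an assumption rather than a confirmed diagnosis: `Net.forward` draws `dummy_input` with `torch.rand`, which advances PyTorch's global random number generator, while `Net2.forward` makes no such call. Even if both runs start from the same seed, every batch drawn after the first forward pass can therefore differ between the two models. The hypothetical helper `draw_batches` below isolates this effect, mimicking the order of draws in `checking` (input batch first, then the dummy tensor inside `forward`):

```python
import torch

def draw_batches(consume_extra, steps=3):
    # Hypothetical helper: reseed, then draw `steps` input batches shaped like the post's
    torch.manual_seed(0)
    batches = []
    for _ in range(steps):
        # Input batch, drawn before the forward pass as in `checking`
        batches.append(torch.rand([3, 1, 3, 3]))
        if consume_extra:
            # Mimics `dummy_input = torch.rand([1], requires_grad=True)` in Net.forward
            torch.rand([1], requires_grad=True)
    return batches

with_dummy = draw_batches(consume_extra=True)
without_dummy = draw_batches(consume_extra=False)

# The extra draw shifts the random stream, so later batches no longer match
print(all(torch.equal(a, b) for a, b in zip(with_dummy, without_dummy)))  # False
```

If this turns out to matter, creating the dummy tensor without touching the RNG (for example `torch.ones(1, requires_grad=True)`) would remove this particular source of difference.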


Do I understand correctly that your issue is that having a checkpoint in your model makes it non-deterministic,
while running the same code without the checkpoint is deterministic?
Running your code without the checkpoint, i.e. with x = self.run_func(self.conv1)(x, dummy_input), gives the same result for me.

Hi albanD,

Sorry for the late response. I am not sure how determinism affects backpropagation here.
The only difference between the two models is whether checkpoint is used on the first Conv2d layer (conv1).

My understanding of checkpoint is that it saves just enough state (the inputs and, by default, the RNG state) to rerun a layer's forward pass when that layer's gradients are computed. This trades compute for GPU memory: intermediate activations are recomputed during the backward pass instead of being stored. So the gradients should be the same whether or not checkpoint() is used.
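The claim that checkpointing should not change gradients can be checked on a single layer. The sketch below compares one backward pass through a plain `nn.Conv2d` with the same pass through a checkpointed copy that has identical weights, using a fixed input and an RNG-free dummy tensor (`torch.ones`) in place of `torch.rand`. It assumes a recent PyTorch release, where `checkpoint` accepts `use_reentrant` (`True` selects the original, reentrant implementation); on 1.5.1 that argument does not exist and should be dropped.

```python
import copy
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
conv_ckpt = copy.deepcopy(conv)  # same initial weights as conv

x = torch.rand(3, 1, 3, 3)
# Dummy input that requires grad, so the checkpointed output requires grad
dummy = torch.ones(1, requires_grad=True)

# Plain forward/backward: activations are stored for backward
conv(x).sum().backward()

# Checkpointed forward/backward: the conv forward is recomputed during backward
out = checkpoint(lambda inp, _unused: conv_ckpt(inp), x, dummy, use_reentrant=True)
out.sum().backward()

# Compare weight and bias gradients of the two copies
for (name, p), p_ckpt in zip(conv.named_parameters(), conv_ckpt.parameters()):
    print(name, torch.allclose(p.grad, p_ckpt.grad))
```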

I expect the gradients of both models to be the same as long as I feed them the same inputs and output labels.
Since I initialize them with the same weights, I would expect the weights of the two models to stay close after the updates. Instead, I get a big difference.
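To test the expectation that identically initialized models fed identical data end up with matching weights, it helps to rule out every other source of difference. The following is a minimal sketch under stated assumptions: `TinyNet` and `train` are hypothetical, reduced stand-ins for `Net`/`Net2` and `checking` (without BatchNorm, to keep it short), all batches and targets are drawn before training, the starting weights are copied with `load_state_dict`, and the dummy input is `torch.ones` so neither model consumes random numbers in `forward`. It assumes a recent PyTorch where `checkpoint` accepts `use_reentrant` (drop the argument on 1.5.1).

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class TinyNet(nn.Module):
    # Hypothetical reduced model: conv -> linear, with optional checkpointing of the conv
    def __init__(self, use_checkpoint):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(9, 3)

    def forward(self, x):
        if self.use_checkpoint:
            # RNG-free dummy that requires grad, so the checkpoint output requires grad
            dummy = torch.ones(1, requires_grad=True)
            x = checkpoint(lambda inp, _unused: self.conv1(inp), x, dummy, use_reentrant=True)
        else:
            x = self.conv1(x)
        return torch.relu(self.fc1(x.view(x.size(0), -1)))

def train(model, batches, targets):
    # Hypothetical stand-in for `checking`: plain Adam training over pre-drawn data
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    criterion = nn.MSELoss()
    for x, y in zip(batches, targets):
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    return {name: p.detach().clone() for name, p in model.named_parameters()}

torch.manual_seed(0)
# Draw all data up front so nothing inside forward can change it
batches = [torch.rand(3, 1, 3, 3) for _ in range(50)]
targets = [torch.eye(3)[torch.randint(0, 3, (3,))] for _ in range(50)]

model_a = TinyNet(use_checkpoint=True)
model_b = TinyNet(use_checkpoint=False)
model_b.load_state_dict(model_a.state_dict())  # identical starting weights

weights_a = train(model_a, batches, targets)
weights_b = train(model_b, batches, targets)
for name in weights_a:
    print(name, torch.allclose(weights_a[name], weights_b[name], atol=1e-6))
```

If the weights still diverge under these controlled conditions, that would point at checkpoint itself; if they match, the difference in the original experiment more likely comes from how the data and weights are prepared.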

For example, here is the difference for the conv1 layer. Most of the weight differences are tiny (< 10^-7), but two of them are very large (exactly 0.2), which does not make sense to me.

AssertionError: conv1.weight

# with checkpoint() on conv1
tensor([[[[-0.1025,  0.2788, -0.3743],
          [-0.1453, -0.0284, -0.0106],
          [ 0.0934,  0.3643, -0.1296]]]])
# without checkpoint()
tensor([[[[-0.1025,  0.2788, -0.3743],
          [-0.1453, -0.0284, -0.0106],
          [-0.1066,  0.1643, -0.1296]]]])

# conv1 weight difference between the two models (with minus without checkpoint)
tensor([[[[-1.4901e-08,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  2.2352e-08],
          [2.0000e-01,  2.0000e-01,  0.0000e+00]]]])
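One possible mechanism for a difference of exactly 0.2 with `lr=0.1`, offered as a hypothesis that the thread does not confirm: Adam divides the (bias-corrected) first moment by the square root of the second, so on its first step a weight whose gradient is much larger than `eps` moves by almost exactly `lr` in the direction opposite the gradient's sign, regardless of the gradient's magnitude. Two runs whose gradients for a weight are tiny but opposite in sign can then end up about `2 * lr` apart. The hypothetical helper `one_adam_step` sketches a single such step:

```python
import torch

def one_adam_step(grad_value, lr=0.1):
    # Hypothetical helper: one scalar weight, initialized to 0, updated once by Adam
    w = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.Adam([w], lr=lr)
    w.grad = torch.tensor([grad_value])  # set the gradient directly instead of calling backward()
    optimizer.step()
    return w.item()

# Gradients that are tiny and of equal size but opposite sign
w_pos = one_adam_step(1e-4)
w_neg = one_adam_step(-1e-4)
print(w_pos, w_neg, abs(w_pos - w_neg))  # the gap is close to 2 * lr = 0.2
```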

Thanks for your time