Different gradients with/without using torch.utils.checkpoint in the model
I wrote a small test to check whether I am using torch.utils.checkpoint correctly. Strangely, when I call checkpoint to run the first Conv2d layer, the gradient updates end up different from those of an otherwise identical model that does not call checkpoint at all. I am not sure where it goes wrong. Thanks
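For reference, this is the equivalence I expected, checked in isolation (a minimal sketch; the single layer and input here are made up for illustration): with a deterministic layer and an input that requires grad, checkpoint should reproduce the plain gradients exactly.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
x = torch.rand(1, 1, 3, 3, requires_grad=True)

# Backward through the checkpointed layer (forward is recomputed during backward).
checkpoint(conv, x).sum().backward()
grad_ckpt = conv.weight.grad.clone()
conv.zero_grad()

# Backward through the same layer called directly.
conv(x).sum().backward()
grad_plain = conv.weight.grad.clone()

print(torch.allclose(grad_ckpt, grad_plain))  # should print True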
Below is my experiment code, run on PyTorch 1.5.1:
import torch
import copy
from torch import nn
from torch.utils.checkpoint import checkpoint


class submodule(nn.Module):
    def __init__(self):
        super(submodule, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        self.conv2 = submodule()
        self.fc1 = nn.Linear(1 * 3 * 3, 3)

    def run_func(self, cell):
        # Wrap the module so checkpoint can pass extra tensors through.
        def call_cell(*inputs):
            return cell(inputs[0])
        return call_cell

    def forward(self, x):
        # dummy_input requires grad so the checkpointed segment takes part
        # in backward even though x itself does not require grad.
        dummy_input = torch.rand([1], requires_grad=True)
        x = checkpoint(self.run_func(self.conv1), x, dummy_input)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return x


class Net2(Net):
    # Identical to Net, except conv1 is called directly, without checkpoint.
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return x


def checking(model):
    torch.manual_seed(0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    parameter_result = []
    parameter_result.append(copy.deepcopy(model.state_dict()))
    num_classes = 3
    for i in range(499):
        model.zero_grad()
        ref = model(torch.rand([3, 1, 3, 3]))
        hyp = torch.eye(num_classes)[torch.randint(low=0, high=3, size=[3])]
        criteria = torch.nn.MSELoss()
        loss = criteria(ref, hyp)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            parameter_result.append(copy.deepcopy(model.state_dict()))
    parameter_result.append(copy.deepcopy(model.state_dict()))
    return parameter_result


torch.manual_seed(0)
model1 = Net()
torch.manual_seed(0)
model2 = Net2()
parameter_result1 = checking(model1)
parameter_result2 = checking(model2)
for index, (a, b) in enumerate(zip(parameter_result1, parameter_result2)):
    print(index)
    for w in b.keys():
        assert torch.allclose(a[w], b[w]), "{}\n{}\n{}\n{}".format(w, a[w], b[w], a[w] - b[w])
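To narrow down where the two runs diverge, I would also compare a single forward/backward on fixed data, so the random input/target stream cannot differ between the models (a sketch on top of the classes above; the constant input and one-hot targets are made up for illustration). If every line prints True, the per-step gradients agree and the divergence comes from the training loop rather than from checkpoint's backward.

torch.manual_seed(0)
m1 = Net()
torch.manual_seed(0)
m2 = Net2()

x = torch.ones(3, 1, 3, 3)  # fixed input instead of torch.rand
target = torch.eye(3)       # fixed one-hot targets

for m in (m1, m2):
    m.zero_grad()
    torch.nn.MSELoss()(m(x), target).backward()

# Compare gradients parameter by parameter instead of state_dicts after many updates.
for (name, p1), (_, p2) in zip(m1.named_parameters(), m2.named_parameters()):
    print(name, torch.allclose(p1.grad, p2.grad))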