Bug in memory efficient DenseNet?

Dear PyTorchers,

I am trying to train the DenseNet model implemented in torchvision in its memory-efficient mode (memory efficiency is obtained by trading extra computation for memory via gradient checkpointing).
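For reference, here is a toy sketch of gradient checkpointing by itself. It uses a small nn.Sequential, not the torchvision DenseNet code, and assumes a PyTorch release recent enough to accept the use_reentrant argument of torch.utils.checkpoint.checkpoint. Checkpointing should change only memory use and compute, never the gradients:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Toy block standing in for any expensive sub-network (assumption, not DenseNet).
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8)


def weight_grad(use_checkpoint: bool) -> torch.Tensor:
    """Run one forward/backward pass and return the first layer's weight gradient."""
    block.zero_grad()
    if use_checkpoint:
        # Intermediate activations of `block` are discarded after the forward
        # pass and recomputed during backward: less memory, more compute.
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return block[0].weight.grad.clone()


plain = weight_grad(use_checkpoint=False)
ckpt = weight_grad(use_checkpoint=True)
print(torch.allclose(plain, ckpt))  # True
```

If a checkpointed model stops updating some weights, that is a sign the gradient path has been cut, not an expected cost of checkpointing.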

However, the minimal DenseNet example below shows that some of the weights are not being updated: the weights of conv1 are unchanged, whereas those of conv2 change.

Is this a bug? Has anyone encountered this?

EDIT: I tried running test_checkpoint_module_list, defined in pytorch/test/test_utils.py, on DenseNet, and the test fails.

Kind regards,
Vardan

import torch
from torchvision import datasets, transforms
from torchvision.models.densenet import _densenet

# Small DenseNet-121-style model (one layer per dense block) with
# gradient checkpointing enabled via memory_efficient=True.
model = _densenet('densenet121', 32, (1, 1, 1, 1), 64,
                  pretrained=False, progress=True,
                  memory_efficient=True,
                  num_classes=10)

train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

train_set = datasets.CIFAR10('data', train=True, download=True,
                             transform=train_transforms)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=4)

optimizer = torch.optim.SGD(model.parameters(), lr=1)

# Take a single batch.
images, target = next(iter(train_loader))

output = model(images)
loss = torch.nn.functional.cross_entropy(output, target)

# First dense layer of the first dense block (list(model.features)[4] is denseblock1).
layer = model.features.denseblock1.denselayer1

# Snapshot the weights before the update; detach so the copies stay out of the graph.
conv1 = layer.conv1.weight.detach().clone()
conv2 = layer.conv2.weight.detach().clone()

optimizer.zero_grad()
loss.backward()
optimizer.step()

# A nonzero norm means the optimizer step changed the weights.
print(torch.norm(conv1 - layer.conv1.weight.detach()))
print(torch.norm(conv2 - layer.conv2.weight.detach()))
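To check every parameter at once rather than only conv1 and conv2, a small helper can snapshot all parameters, take one optimizer step, and report which ones did not move. This is a sketch: unchanged_parameters and Toy are hypothetical names, not PyTorch or torchvision APIs, and the toy model uses .detach() to cut the graph on purpose, mimicking the symptom described above:

```python
from typing import Callable, List

import torch
import torch.nn as nn


def unchanged_parameters(model: nn.Module,
                         loss_fn: Callable[[nn.Module], torch.Tensor],
                         optimizer: torch.optim.Optimizer) -> List[str]:
    """Take one optimizer step and return the names of parameters whose values did not change."""
    # Snapshot every parameter before the update.
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    optimizer.zero_grad()
    loss_fn(model).backward()
    optimizer.step()
    return [name for name, p in model.named_parameters()
            if torch.equal(before[name], p.detach())]


class Toy(nn.Module):
    """Hypothetical two-layer model whose first layer never receives gradients."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 1)

    def forward(self, x):
        # .detach() cuts the graph deliberately, so self.a gets no gradient.
        return self.b(self.a(x).detach())


torch.manual_seed(0)
model = Toy()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
x = torch.randn(8, 4)
result = unchanged_parameters(model, lambda m: m(x).pow(2).mean(), optimizer)
print(result)  # ['a.weight', 'a.bias']
```

A parameter whose gradient happens to be exactly zero would also be listed, so treat the output as a hint to inspect .grad rather than proof of a broken graph.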

I found that training performance is quite different with and without the memory-efficient (checkpointed) implementation.

Do you know what is wrong?

@Marksein @papyan I just ran into this issue and found this thread while looking for confirmation. I'm pretty sure I fixed the bug.

Change the original code here…

    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

to

    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
        def closure(*inputs):
            return self.bn_function(inputs) # remove *

        return cp.checkpoint(closure, *input) # add *

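Very subtle! In its reentrant mode (the default when this code was written), cp.checkpoint only tracks tensors that are passed directly as positional arguments. The original code passes the whole list as a single argument, so autograd never sees those tensors: no gradient flows back through the checkpointed bottleneck (norm1, relu1, conv1), which is why conv1 is never updated. Unpacking the list with *input passes each tensor as its own argument, so each one is tracked again.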
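A minimal sketch reproducing the difference outside torchvision. The conv1 and bn_function here are hypothetical stand-ins for the dense layer's bottleneck, and the code assumes a PyTorch version that accepts use_reentrant, passing use_reentrant=True explicitly to match the reentrant behaviour of the original code:

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for the dense layer's bottleneck conv1.
conv1 = nn.Conv2d(6, 4, kernel_size=1)


def bn_function(inputs):
    # Like DenseNet's bn_function: concatenate previous feature maps, then convolve.
    return conv1(torch.cat(inputs, 1))


# Two previous feature maps that require grad, as inside a dense block.
feats = [torch.randn(2, 3, 8, 8, requires_grad=True) for _ in range(2)]


# Original version: the whole list is passed as ONE non-tensor argument.
def closure_buggy(*inputs):
    return bn_function(*inputs)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hides "None of the inputs have requires_grad=True"
    out_buggy = checkpoint(closure_buggy, feats, use_reentrant=True)
# Reentrant checkpoint saw no tensor inputs, so its output is cut off from autograd.
print(out_buggy.requires_grad)  # False


# Fixed version: each tensor is passed as its own positional argument.
def closure_fixed(*inputs):
    return bn_function(inputs)


out_fixed = checkpoint(closure_fixed, *feats, use_reentrant=True)
out_fixed.sum().backward()
print(conv1.weight.grad is not None, feats[0].grad is not None)  # True True
```

According to the PyTorch checkpoint documentation, the non-reentrant variant (use_reentrant=False) does treat tensors inside nested structures such as lists as participating in autograd; with reentrant checkpointing, unpacking the list is what restores the gradient flow.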
