Dear PyTorchers,
I am trying to train the DenseNet model implemented in torchvision in its memory-efficient mode, which reduces memory usage at the cost of extra computation by using gradient checkpointing.
However, the minimal example below shows that some of the weights are not being updated: after an optimizer step, the conv1 weights of the first dense layer are unchanged, whereas its conv2 weights do change.
Is this a bug? Has anyone encountered this?
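For reference, here is a minimal sketch of the checkpointing pattern the memory-efficient mode relies on, written with plain torch.utils.checkpoint rather than torchvision. In torchvision's _DenseLayer, as far as I can tell, the checkpointed function is the concatenation of the previous features followed by norm1, relu1 and conv1, while norm2, relu2 and conv2 run outside it. That would match conv1 not updating while conv2 does. The module names and shapes below are illustrative assumptions, and use_reentrant assumes PyTorch 1.11 or later.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical bottleneck mirroring the first half of a dense layer:
# concatenate the previous feature maps, then BN -> ReLU -> 1x1 conv.
norm1 = nn.BatchNorm2d(8)
relu1 = nn.ReLU()
conv1 = nn.Conv2d(8, 16, kernel_size=1, bias=False)

def bottleneck(*prev_features):
    concated = torch.cat(prev_features, 1)
    return conv1(relu1(norm1(concated)))

# Two 4-channel feature maps that require grad, as they would when they
# are produced by earlier trainable layers.
f0 = torch.randn(2, 4, 8, 8, requires_grad=True)
f1 = torch.randn(2, 4, 8, 8, requires_grad=True)

# Intermediate activations inside `bottleneck` are not kept; they are
# recomputed during the backward pass.
out = checkpoint(bottleneck, f0, f1, use_reentrant=False)
out.sum().backward()

print(conv1.weight.grad is not None)  # True
```

If checkpointing behaves as intended, the convolution inside the checkpointed segment receives a gradient just like one outside it would.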
EDIT: I also tried running test_checkpoint_module_list, defined in pytorch/test/test_utils.py, on DenseNet, and the test fails.
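A test of that kind essentially checks that checkpointing does not change the gradients. A minimal sketch of such a comparison on a small stand-in module follows; it is not the PyTorch test itself, and grads_match is a hypothetical helper. An identical copy of the module is run once normally and once through checkpoint, and every parameter gradient is compared.

```python
import copy
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for a dense layer.
layer = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 1))
layer_cp = copy.deepcopy(layer)  # identical weights

x = torch.randn(2, 3, 8, 8, requires_grad=True)
x_cp = x.detach().clone().requires_grad_()

layer(x).sum().backward()                                       # plain forward
checkpoint(layer_cp, x_cp, use_reentrant=False).sum().backward()  # checkpointed

def grads_match(a, b):
    """Hypothetical helper: True if every parameter got the same gradient in both modules."""
    for (name, pa), (_, pb) in zip(a.named_parameters(), b.named_parameters()):
        if pa.grad is None or pb.grad is None:
            print(f"{name}: missing gradient")
            return False
        if not torch.allclose(pa.grad, pb.grad):
            print(f"{name}: gradients differ")
            return False
    return True

print(grads_match(layer, layer_cp))  # expected: True
```

Pointing the same comparison at two copies of a DenseNet layer would show whether the checkpointed copy is missing the conv1 gradient.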
Kind regards,
Vardan
```python
import torch
from torchvision import datasets, transforms
# Private helper; this signature is from older torchvision releases and may differ in newer ones
from torchvision.models.densenet import _densenet

# DenseNet with growth rate 32, one layer per dense block, 64 initial features,
# and gradient checkpointing enabled
model = _densenet('densenet121', 32, (1, 1, 1, 1), 64,
                  pretrained=False, progress=True,
                  memory_efficient=True,
                  num_classes=10)

train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])
train_set = datasets.CIFAR10('data', train=True, download=True,
                             transform=train_transforms)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4)
optimizer = torch.optim.SGD(model.parameters(), lr=1)

# A single batch is enough to reproduce the problem
inputs, target = next(iter(train_loader))
output = model(inputs)
loss = torch.nn.functional.cross_entropy(output, target)

# Snapshot both convolutions of the first layer in the first dense block
layer = model.features.denseblock1.denselayer1
conv1 = layer.conv1.weight.detach().clone()
conv2 = layer.conv2.weight.detach().clone()

optimizer.zero_grad()
loss.backward()
optimizer.step()

# A non-zero norm means the weights were updated by the step
print(torch.norm(conv1 - layer.conv1.weight.detach()))
print(torch.norm(conv2 - layer.conv2.weight.detach()))
```
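To see every parameter that did not move, rather than checking two of them by hand, one option is to snapshot all parameters before the step and compare afterwards. This is a minimal sketch on a hypothetical toy model; snapshot and unchanged_params are illustrative helpers, not library functions. The first layer is frozen on purpose so the check has something to report. It also separates "no gradient reached the parameter" (grad is None) from "a gradient arrived but was zero".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def snapshot(model):
    # Detached copies of every parameter, keyed by its qualified name.
    return {name: p.detach().clone() for name, p in model.named_parameters()}

def unchanged_params(model, before):
    # Names of parameters whose values are identical to the snapshot.
    return [name for name, p in model.named_parameters()
            if torch.equal(p.detach(), before[name])]

torch.manual_seed(0)
# Hypothetical toy model standing in for the DenseNet above.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
model[0].requires_grad_(False)  # frozen on purpose

optimizer = torch.optim.SGD(model.parameters(), lr=1)
x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))

before = snapshot(model)
optimizer.zero_grad()
F.cross_entropy(model(x), y).backward()
optimizer.step()

stale = unchanged_params(model, before)
print(stale)  # ['0.weight', '0.bias']
for name, p in model.named_parameters():
    if name in stale:
        # None: no gradient reached it; a tensor (possibly all zeros): one did.
        print(name, "grad is None" if p.grad is None else "grad present")
```

Run on the DenseNet from the reproduction, the same loop would list every parameter that the step left untouched.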