Memory Error while trying to save resnet features to a dictionary

I’m fairly new to PyTorch and I’m trying to extract the features of my training data using a ResNet101 and save them to a dictionary. I then intend to use this dictionary as a lookup table for another model.

However, I keep running out of memory after processing only 8 images on a machine with a 12GB GPU. The for loop is:

splits = ['train2014', 'val2014', 'test2015']
for split in splits:
    RES_FILE = os.path.join(D_ROOT, 'BaselineTraining', split, 'baseline_{}_cnn_features.pth'.format(split))
    IMAGE_DIR = os.path.join(D_ROOT, 'Images', 'mscoco', split)
    images = os.listdir(IMAGE_DIR)
    images_path = [os.path.join(IMAGE_DIR, file) for file in images]
    res = torch.load(RES_FILE) if os.path.exists(RES_FILE) else {}
    for i in progressbar.progressbar(range(len(images))):
        #Checkpoint every 2,500 images
        if i % 2500 == 0 and i > 0:
            torch.save(res, RES_FILE)
        image_name, image = images[i], cv2.imread(images_path[i])
        if image_name in res:
            continue
        image = process_image(image)
        image = np.expand_dims(image, axis=0)
        image = torch.from_numpy(image)
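        # Reorder NHWC -> NCHW (assuming process_image returns an H x W x C array);
        # view() would only reinterpret the memory layout, not move the channel axis.
        image = image.permute(0, 3, 1, 2)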
        image = image.float()
        image = image.cuda()
        out = resnet101.forward(image)
        #res[image_name] = out.cpu()
    print('Successfully processed {}.'.format(split))
print('Successfully processed all images.')

Also, I have noticed that the problem only occurs when I uncomment the commented line. Without it, extraction goes smoothly. What is happening here?

Hi,

What happens is that you seem to be running this code with autograd enabled.
Since the parameters of the resnet require gradients, your out does as well (and so it carries all the history needed to compute these gradients). When you save out.cpu() in your dictionary, you still hold onto this history for every image, so the memory usage keeps increasing.
The best thing to do here, I think, is to disable autograd, either by decorating your function with @torch.no_grad() or by running the model within the with torch.no_grad(): context manager.
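
As a minimal sketch of the difference (the tiny model, the random input and the "example.jpg" key are stand-ins I made up so this runs quickly on CPU; they are not the code from the question), you can compare the output of a forward pass with and without torch.no_grad():

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for resnet101: a tiny model whose parameters require gradients.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
model.eval()

# Random tensor standing in for one preprocessed image in NCHW layout.
image = torch.rand(1, 3, 32, 32)

# Autograd enabled: the output references the graph that produced it.
out_with_graph = model(image)

# Autograd disabled: no graph is recorded, so nothing extra is kept alive.
with torch.no_grad():
    out_no_graph = model(image)

# Storing the graph-free output keeps only the feature tensor itself.
features = {}
features["example.jpg"] = out_no_graph.cpu()

print(out_with_graph.requires_grad, out_with_graph.grad_fn is not None)
print(out_no_graph.requires_grad, out_no_graph.grad_fn is None)
```

In the loop from the question, this would amount to wrapping the model call (or the whole loop) in the with torch.no_grad(): block before assigning to res[image_name].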

PS: you should not call the .forward() method of modules directly, but call them with the input: out = resnet101(image) in your case. Otherwise features like hooks won’t work.
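
To illustrate the point about hooks, here is a small sketch (the Linear layer and the record_output_shape function are hypothetical examples, not part of the original code) showing that a forward hook runs when the module is called, but not when .forward() is invoked directly:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []

def record_output_shape(module, inputs, output):
    # Forward hook: remember the shape of every output the module produces.
    calls.append(tuple(output.shape))

handle = layer.register_forward_hook(record_output_shape)
x = torch.rand(3, 4)

with torch.no_grad():
    layer.forward(x)                 # bypasses __call__, so the hook is skipped
    count_after_forward = len(calls)
    layer(x)                         # goes through __call__, so the hook runs
    count_after_call = len(calls)

handle.remove()
print(count_after_forward, count_after_call)
```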

Thanks! It worked like a charm. Memory usage is now more or less constant :)
