GPU memory for model in train mode is enough, but not enough for model in eval mode. Why?

Usually, it would be the other way around.

You may need to check the batch size for test mode (whether it is higher than in training), as well as whether you are testing with torch.no_grad().
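
For what it's worth, here is a minimal evaluation sketch for PyTorch >= 0.4 (model, test_loader, and device are placeholder names, not from your code):

import torch

def evaluate(model, test_loader, device):
    # eval mode switches dropout/batch-norm to inference behaviour;
    # it does NOT by itself stop autograd from building a graph
    model.eval()
    correct, total = 0, 0
    # no_grad() tells autograd not to store intermediate activations,
    # which is usually what frees the extra memory at test time
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.max(dim=1)[1]
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total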

If the reason is not one of those mentioned above, you may need to post a short code snippet with your dataloader/test method to get a valid answer.

The dataloader is defined as follows:

import json
import h5py
import os
from PIL import Image
from PIL.ImageOps import expand
import numpy as np
import torch.utils.data as data
import multiprocessing
import random
import torchvision
from random import choice
import torch


train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(260),
                                                     torchvision.transforms.RandomResizedCrop(224),
                                                     torchvision.transforms.RandomHorizontalFlip(),
                                                     torchvision.transforms.ColorJitter(random.randint(0, 1)),
                                                     torchvision.transforms.ToTensor(),
                                                     torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(224),
                                                    torchvision.transforms.ToTensor(),
                                                    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

class DataLoader(data.Dataset):
    def reset_iterator(self, split):
        del self._prefetch_process[split]
        self._prefetch_process[split] = BlobFetcher(split,
                                                    self, split == 'train')
        self.iterators[split] = 0

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def word_to_ix(self):
        w2ix = {}
        for k, v in self.ix_to_word.items():
            w2ix[v] = k
        return w2ix

    def get_vb_vocab(self):
        vocab = {}
        for k, v in self.vb_ix_to_word.items():
            vocab[v] = k
        return vocab

    def get_unk_ix(self):
        vb_vocab = self.get_vb_vocab()
        return int(vb_vocab['unk'])

    def get_vb_weights(self):
        vb_weights = []
        total_vb = sum(self.vb_counts.values())
        vb_vocab = self.get_vb_vocab()
        vb_vocab = sorted(vb_vocab.items(), key=lambda item:int(item[1]))
        for each in vb_vocab:
            weight = total_vb / (int(self.vb_counts[each[0]]) * 10000)
            vb_weights.append(weight)
        return torch.from_numpy(np.asarray(vb_weights)).float()

    def get_seq_length(self):
        return self.seq_length

    def read_files(self):
        self.feats_fc = h5py.File(os.path.join(
            self.opt.input_fc_dir, 'feats_fc.h5'), 'r')
        self.feats_att = h5py.File(os.path.join(
            self.opt.input_att_dir, 'feats_att.h5'), 'r')

    def get_data(self, ix, split):
        img_path = os.path.join(self.imgpth, self.info['images'][ix]['file_path'])

        image = Image.open(img_path)
        # PIL's Image.size is (width, height); pad the shorter side to make the image square
        w, h = image.size
        g = max(w, h)
        delta_w = g - w
        delta_h = g - h
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        image = expand(image, padding)
        if split == 'train':
            image = train_augmentation(image)
        else:
            image = test_augmentation(image)
        return image, ix

    def __init__(self, opt):
        self.opt = opt
        self.batch_size = self.opt.batch_size
        self.imgpth = opt.input_img_path

        # load json file which contains additional information about dataset
        print('DataLoader loading json file: ', opt.input_json)
        self.info = json.load(open(self.opt.input_json))
        self.ix_to_word = self.info['ix_to_word']
        self.w2ix = self.word_to_ix()
        self.vocab_unk_ix = int(self.w2ix['UNK'])
        self.vocab_size = len(self.ix_to_word)
        print('vocab size is ', self.vocab_size)
        self.vb_ix_to_word = self.info['vb_ix_to_word']
        self.vb_counts = self.info['vb_counts']
        self.vb_vocab_size = len(self.vb_ix_to_word)
        self.noun_labels = self.info['noun_labels']
        print('vb vocab size is ', self.vb_vocab_size)

        # open the hdf5 file
        print('DataLoader loading h5 file: ', opt.input_label_h5)
        self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r',
                                       driver='core')

        # load in the sequence data
        seq_size = self.h5_label_file['labels'].shape
        self.seq_length = seq_size[1]
        print('max sequence length in data is', self.seq_length)
        # load the pointers in full to RAM (should be small enough)
        self.label_start_ix = self.h5_label_file['label_start_ix'][:]
        self.label_end_ix = self.h5_label_file['label_end_ix'][:]
        self.noun_mask = self.h5_label_file['noun_mask'][:]

        self.num_images = self.label_start_ix.shape[0]
        print('read %d image features' % (self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if img['split'] == 'train':
                self.split_ix['train'].append(ix)
            # elif img['split'] == 'val':
            #     self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        print('assigned %d images to split train' % len(self.split_ix['train']))
        # print('assigned %d images to split val' % len(self.split_ix['val']))
        print('assigned %d images to split test' % len(self.split_ix['test']))

        self.iterators = {'train': 0, 'test': 0}

        self._prefetch_process = {}  # one prefetch process per split
        for split in self.iterators.keys():
            self._prefetch_process[split] = BlobFetcher(split,
                                                        self,
                                                        split == 'train')

        # Terminate the child processes when the parent exits
        def cleanup():
            print('Terminating BlobFetcher')
            for split in self.iterators.keys():
                del self._prefetch_process[split]

        import atexit
        atexit.register(cleanup)

    def get_batch(self, split, batch_size=None):
        batch_size = batch_size or self.batch_size
        img_batch = np.zeros([batch_size, 3, 224, 224], dtype='float32')
        label_batch = np.zeros([batch_size, self.seq_length + 2], dtype='int')
        vb_label_batch = np.zeros([batch_size, self.vb_vocab_size], dtype='int')
        mask_batch = np.zeros([batch_size, self.seq_length + 2], dtype='float32')
        noun_batch = np.zeros([batch_size, self.vocab_size + 1], dtype='float32')
        noun_mask_batch = np.zeros([batch_size, 49], dtype='float32')
        masks = []
        wrapped = False
        infos = []
        gts = []

        for i in range(batch_size):
            img, ix, tmp_wrapped = self._prefetch_process[split].get()
            img_batch[i] = img

            ix1 = self.label_start_ix[ix]
            ix2 = self.label_end_ix[ix] - 1
            ncap = ix2 - ix1 + 1

            for x in range(ix1, ix2):
                if len(self.h5_label_file['vblabels'][x]) > 0:
                    ind = self.h5_label_file['vblabels'][x].nonzero()
                    for each in ind:
                        vb_label_batch[i][self.h5_label_file['vblabels'][x][each]] = 1

            assert ncap > 0, 'an image does not have any label'

            ixl = random.randint(ix1, ix2)
            label_batch[i, 1: self.seq_length + 1] = self.h5_label_file['labels'][ixl]
            nouns = self.noun_labels[ixl]['subject']
            noun_mask_batch[i] = self.noun_mask[ixl]

            l_nouns = len(nouns)
            if l_nouns > 0:
                seed = random.randint(0, l_nouns - 1)
                masks.append(nouns[seed])
                for each in nouns[seed]:
                    noun_batch[i][each - 1] = 1
            else:
                noun_batch[i][self.vocab_unk_ix - 1] = 1
            # vb_labels = ((self.h5_label_file['vblabels'][ixl] >= 1) * self.h5_label_file['vblabels'][ixl]).nonzero()[0]
            # if len(self.h5_label_file['vblabels'][ixl]) >= 1:
            #     for each in self.h5_label_file['vblabels'][ixl]:
            #         vb_label_batch[i][each - 1] =1
            # if len(vb_labels) == 0:
            #     vb_label_batch[i] = self.get_unk_ix() - 1
            # else:
            #     vb_label_batch[i] = self.h5_label_file['vblabels'][ixl][choice(vb_labels)] - 1

            if tmp_wrapped:
                wrapped = True
            gts.append(self.h5_label_file['labels'][self.label_start_ix[ix]:self.label_end_ix[ix]])

            info_dict = {}
            info_dict['ix'] = ix
            info_dict['id'] = int(self.info['images'][ix]['id'])
            info_dict['file_path'] = self.info['images'][ix]['file_path']
            infos.append(info_dict)

        nonzeros = np.array(list(map(lambda x: (x != 0).sum() + 2, label_batch)))
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1

        data = {}
        # print('img_batch size: ', img_batch.size(0), img_batch.size(1))
        data['imgs'] = img_batch
        data['labels'] = label_batch
        data['vb_labels'] = vb_label_batch
        data['gts'] = gts
        data['masks'] = mask_batch
        data['noun_batch'] = noun_batch
        data['noun_mask_batch'] = noun_mask_batch
        data['bounds'] = {'it_pos_now': self.iterators[split],
                          'it_max': len(self.split_ix[split]),
                          'wrapped': wrapped}
        data['infos'] = infos
        data['masks_'] = masks

        return data

    # It's not coherent to make DataLoader a subclass of Dataset,
    # but essentially we only need to implement the following two functions
    # so that torch.utils.data.DataLoader can load the data according to
    # the index. However, this is the minimal change needed to switch to
    # PyTorch data loading.
    def __getitem__(self, index):
        ix = index
        return self.get_data(ix, split='train')

    def __len__(self):
        return len(self.info['images'])


class BlobFetcher():
    def __init__(self, split, dataloader, is_shuffle=False):
        self.split = split
        self.dataloader = dataloader
        self.is_shuffle = is_shuffle

    def reset(self):
        sampler = self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]
        self.split_loader = iter(
            data.DataLoader(dataset=self.dataloader,
                            batch_size=1,
                            sampler=sampler,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=multiprocessing.cpu_count(),
                            collate_fn=lambda x: x[0]))

    def _get_next_minibatch_inds(self):
        max_index = len(self.dataloader.split_ix[self.split])
        wrapped = False

        ri = self.dataloader.iterators[self.split]
        ix = self.dataloader.split_ix[self.split][ri]

        ri_next = ri + 1
        if ri_next >= max_index:
            ri_next = 0
            if self.is_shuffle:
                random.shuffle(self.dataloader.split_ix[self.split])
            wrapped = True
        self.dataloader.iterators[self.split] = ri_next

        return ix, wrapped

    def get(self):
        if not hasattr(self, 'split_loader'):
            self.reset()
        ix, wrapped = self._get_next_minibatch_inds()
        tmp = next(self.split_loader)  # works on both Python 2 and 3
        if wrapped:
            self.reset()
        assert tmp[1] == ix, 'ix not equal'

        return list(tmp) + [wrapped]  # ensure a list before appending the wrapped flag

Are the batch sizes of the train and test modes the same?
Are you using torch.no_grad() while testing?

Yes, the batch sizes of the train and test modes are the same. I do not use torch.no_grad(), because my torch version is 0.3.0.

You could set volatile=True for data and target if your PyTorch version is < 0.4:

data = Variable(data, volatile=True)
target = Variable(target, volatile=True)
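
For completeness, a minimal test-loop sketch under that assumption (torch < 0.4; model and test_loader are placeholders, and the classification loss is only an example):

import torch.nn.functional as F
from torch.autograd import Variable

def test(model, test_loader):
    model.eval()
    total_loss = 0.0
    for data, target in test_loader:
        # volatile=True propagates through the whole graph, so autograd
        # keeps no intermediate activations for any downstream op
        data = Variable(data.cuda(), volatile=True)
        target = Variable(target.cuda(), volatile=True)
        output = model(data)
        loss = F.cross_entropy(output, target)
        # accumulate the raw float (loss.data[0] in torch 0.3), not the
        # Variable itself, otherwise the graph would be kept alive
        total_loss += loss.data[0]
    return total_loss / len(test_loader)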

Thanks, it works now!