Out of memory after 1 epoch

Hi guys, I have seen all the posts about running out of memory and tried a lot of things, but nothing works… I am using PyTorch 1.0 and this is my code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Fine-tune the fc layer only for bilinear CNN.

Usage:
    CUDA_VISIBLE_DEVICES=0,1,2,3 ./src/bilinear_cnn_fc.py --base_lr 0.05 \
        --batch_size 64 --epochs 100 --weight_decay 5e-4
"""


import os

import torch
import torchvision

from torch.utils.data import DataLoader, Dataset
import data_generator as whales_generator
import albumentations

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.set_default_dtype(torch.float32)
torch.set_default_tensor_type(torch.FloatTensor)
torch.backends.cudnn.benchmark = True

__all__ = ['BCNN', 'BCNNManager']
__author__ = 'Hao Zhang'
__copyright__ = '2018 LAMDA'
__date__ = '2018-01-09'
__email__ = 'zhangh0214@gmail.com'
__license__ = 'CC BY-SA 3.0'
__status__ = 'Development'
__updated__ = '2018-01-13'
__version__ = '1.2'


class BCNN(torch.nn.Module):
    """B-CNN for WHALES.

    The B-CNN model is illustrated as follows.
    conv1^2 (64) -> pool1 -> conv2^2 (128) -> pool2 -> conv3^3 (256) -> pool3
    -> conv4^3 (512) -> pool4 -> conv5^3 (512) -> bilinear pooling
    -> sqrt-normalize -> L2-normalize -> fc (200).
    The network accepts a 3*448*448 input, and the pool5 activation has shape
    512*28*28 since we down-sample 5 times.

    Attributes:
        features, torch.nn.Module: Convolution and pooling layers.
        fc, torch.nn.Module: 200.
    """
    def __init__(self):
        """Declare all needed layers."""
        torch.nn.Module.__init__(self)
        # Convolution and pooling layers of VGG-16.
        self.features = torchvision.models.vgg16(pretrained=True).features
        self.features = torch.nn.Sequential(
            *list(self.features.children())[:16])  # Keep layers up to conv3_3; see the Jupyter notebook to inspect them.
        # Linear classifier.
        self.fc = torch.nn.Linear(256**2, 5004)

        # Freeze all previous layers.
        for param in self.features.parameters():
            param.requires_grad = False
        # Initialize the fc layers.
        torch.nn.init.kaiming_normal_(self.fc.weight.data)
        if self.fc.bias is not None:
            torch.nn.init.constant_(self.fc.bias.data, val=0)

    def forward(self, X):
        """Forward pass of the network.

        Args:
            X, torch.autograd.Variable of shape N*3*448*448.

        Returns:
            Score, torch.autograd.Variable of shape N*200.
        """
        N = X.size()[0]
        assert X.size() == (N, 3, 224, 224)
        X = self.features(X)
        assert X.size() == (N, 256, 56, 56)
        X = X.view(N, 256, 56**2)
        X = torch.bmm(X, torch.transpose(X, 1, 2)) / (56**2)  # Bilinear
        assert X.size() == (N, 256, 256)
        X = X.view(N, 256**2)
        X = torch.sqrt(X + 1e-5)
        X = torch.nn.functional.normalize(X)
        X = self.fc(X)
        assert X.size() == (N, 5004)
        return X
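
# Shape walk-through of BCNN.forward, for reference (it matches the asserts above):
#   input            (N, 3, 224, 224)
#   self.features    (N, 256, 56, 56)    # two 2x2 max-pools: 224 / 4 = 56
#   view             (N, 256, 3136)      # 56**2 = 3136
#   bmm(X, X^T)      (N, 256, 256)       # bilinear (outer-product) pooling
#   view             (N, 65536)          # 256**2
#   fc               (N, 5004)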


class BCNNManager(object):
    """Manager class to train bilinear CNN.

    Attributes:
        _options: Hyperparameters.
        _path: Useful paths.
        _net: Bilinear CNN.
        _criterion: Cross-entropy loss.
        _optimizer: SGD with momentum.
        _scheduler: Reduce learning rate by a factor of 0.1 on plateau.
        _train_loader: Training data.
        _test_loader: Testing data.
    """
    def __init__(self, options, path):
        """Prepare the network, criterion, solver, and data.

        Args:
            options, dict: Hyperparameters.
        """
        print('Prepare the network and data.')
        self._options = options
        self._path = path
        # Network.
        self._net = torch.nn.DataParallel(BCNN()).cuda()
        print(self._net)
        # Criterion.
        self._criterion = torch.nn.CrossEntropyLoss().cuda()
        # Solver.
        self._optimizer = torch.optim.SGD(
            self._net.module.fc.parameters(), lr=self._options['base_lr'],
            momentum=0.9, weight_decay=self._options['weight_decay'])

        self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self._optimizer, mode='max', factor=0.1, patience=3, verbose=True,
            threshold=1e-4)
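        # mode='max' because the scheduler is stepped on an accuracy value
        # (higher is better); ReduceLROnPlateau multiplies the learning rate
        # by `factor` after `patience` epochs without improvement.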

        data_transforms_train = albumentations.Compose([
            albumentations.Resize(224, 224),
            albumentations.HorizontalFlip(),
            albumentations.Blur(blur_limit=2),
            albumentations.ShiftScaleRotate(rotate_limit=7, scale_limit=0.085),
            albumentations.RandomBrightness(limit=0.12),
            albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
        ])

        train_dataset = whales_generator.WhaleDataset(
            datatype='train',
            transform=data_transforms_train,
            normalization="", # No porque ya lo normalizamos con la transformacion
            color=True, # Necasario para modelo preentrenado
            use_one=False, # Queremos todas las ballenitas!
            humpback_only=True # Solo las colitas
        )

        self._train_loader = DataLoader(
            train_dataset, batch_size=self._options['batch_size'],
            pin_memory=False, shuffle=True)
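        # Note: num_workers is left at its default of 0, so batches are
        # loaded in the main process; pin_memory=False avoids pinning host
        # memory for the host-to-GPU transfers.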


        """
        test_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(size=448),
            torchvision.transforms.CenterCrop(size=448),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                             std=(0.229, 0.224, 0.225))
        ])

        test_data = cub200.CUB200(
            root=self._path['cub200'], train=False, download=True,
            transform=test_transforms)

        self._test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=16,
            shuffle=False, num_workers=4, pin_memory=True)
        """

    def train(self):
        """Train the network."""
        print('Training.')
        best_acc = 0.0
        best_epoch = None
        print('Epoch\tTrain loss\tTrain acc')
        for t in range(self._options['epochs']):
            epoch_loss = 0
            num_correct = 0
            num_total = 0
            total_batches = 0
            for X, y in self._train_loader:
                # Move the batch to the GPU.
                X = X.cuda()
                y = y.cuda()

                # Forward pass.
                score = self._net(X)
                loss = self._criterion(score, y.long())

                with torch.no_grad():
                    epoch_loss += loss.item()
                    # Prediction.
                    prediction = torch.argmax(score, dim=1)
                    num_total += y.size(0)
                    num_correct += torch.sum(prediction == y.long()).item()

                # Backward pass.
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()
                
                total_batches += 1
                del X, y, score, loss, prediction

            train_acc = 100 * num_correct / num_total
            # The test loader is disabled for now, so the scheduler is
            # stepped and the best model is tracked on the training accuracy.
            # test_acc = self._accuracy(self._test_loader)
            self._scheduler.step(train_acc)
            if train_acc > best_acc:
                best_acc = train_acc
                best_epoch = t + 1
                print('*', end='')
                # Save the model onto disk.
                torch.save(self._net.state_dict(),
                           os.path.join(self._path['model'],
                                        'vgg_16_epoch_%d.pth' % (t + 1)))
            print('%d\t%4.3f\t\t%4.2f%%' %
                  (t + 1, epoch_loss / total_batches, train_acc))
        print('Best at epoch %d, train accuracy %f' % (best_epoch, best_acc))

    def _accuracy(self, data_loader):
        """Compute the train/test accuracy.

        Args:
            data_loader: Train/Test DataLoader.

        Returns:
            Train/Test accuracy in percentage.
        """
        self._net.train(False)  # Switch to evaluation mode.
        num_correct = 0
        num_total = 0
        with torch.no_grad():  # No graph is needed for evaluation.
            for X, y in data_loader:
                # Move the batch to the GPU.
                X = X.cuda()
                y = y.cuda(non_blocking=True)

                # Prediction.
                score = self._net(X)
                prediction = torch.argmax(score, dim=1)
                num_total += y.size(0)
                num_correct += torch.sum(prediction == y.long()).item()
        self._net.train(True)  # Set the model back to training phase.
        return 100 * num_correct / num_total

    def getStat(self):
        """Get the mean and std value for a certain dataset."""
        print('Compute mean and variance for training data.')

        data_transforms_train = albumentations.Compose([
            albumentations.Resize(448, 448)
        ])

        train_data = whales_generator.WhaleDataset(
            datatype='train',
            transform=data_transforms_train,
            normalization="",  # Raw pixel values, since the statistics are computed here.
            color=True,
            use_one=False,
            humpback_only=True
        )

        train_loader = DataLoader(train_data, batch_size=1, pin_memory=True, shuffle=True)
        
        mean = torch.zeros(3)
        std = torch.zeros(3)
        for X, _ in train_loader:
            for d in range(3):
                mean[d] += X[:, d, :, :].mean()
                std[d] += X[:, d, :, :].std()
        mean.div_(len(train_data))
        std.div_(len(train_data))
        print(mean)
        print(std)
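        # Note: this averages per-image channel means/stds rather than
        # computing exact statistics over all pixels, an approximation that
        # is usually close enough for albumentations.Normalize.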


def main():
    """The main function."""
    import argparse
    parser = argparse.ArgumentParser(
        description='Train bilinear CNN on WHALES.')
    """
    parser.add_argument('--base_lr', dest='base_lr', type=float, required=True,
                        help='Base learning rate for training.')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        required=True, help='Batch size.')
    parser.add_argument('--epochs', dest='epochs', type=int,
                        required=True, help='Epochs for training.')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float,
                        required=True, help='Weight decay.')
    """
    parser.add_argument('--base_lr', dest='base_lr', type=float, required=False, default=0.1,
                        help='Base learning rate for training.')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=2,
                        required=False, help='Batch size.')
    parser.add_argument('--epochs', dest='epochs', type=int, default=10,
                        required=False, help='Epochs for training.')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-8,
                        required=False, help='Weight decay.')
    args = parser.parse_args()
    if args.base_lr <= 0:
        raise AttributeError('--base_lr must be > 0.')
    if args.batch_size <= 0:
        raise AttributeError('--batch_size must be > 0.')
    if args.epochs < 0:
        raise AttributeError('--epochs must be >= 0.')
    if args.weight_decay <= 0:
        raise AttributeError('--weight_decay must be > 0.')

    options = {
        'base_lr': args.base_lr,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'weight_decay': args.weight_decay,
    }

    project_root = os.getcwd()
    path = {
        'model': os.path.join(project_root, 'mario', 'bilinear', 'checkpoints'),
    }
    for d in path:
        assert os.path.isdir(path[d])

    manager = BCNNManager(options, path)
    # manager.getStat()
    manager.train()


if __name__ == '__main__':
    main()

When it starts, epoch 1 consumes about 5 GB, but during the second epoch it runs out of memory in the backward pass :frowning:
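
For reference, this is roughly how the 5 GB figure can be checked from inside the script (a minimal sketch; torch.cuda.memory_allocated and torch.cuda.max_memory_allocated are standard PyTorch calls, while log_gpu_memory and its tag argument are just names I made up):

def log_gpu_memory(tag):
    # Print currently allocated and peak allocated CUDA memory in MB.
    alloc = torch.cuda.memory_allocated() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print('[%s] allocated: %.1f MB, peak: %.1f MB' % (tag, alloc, peak))

# e.g. call log_gpu_memory('epoch %d' % (t + 1)) at the end of each epoch in train().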