Terrible memory leak while training bayesian network

Hi
I’m training network with two fully-connected layers and sparse variational dropout, using MNIST data. I’ve encountered terrible memory leak: after 100 epochs, more than 200GB RAM is used. Problem is specific for Ubuntu. On Mac and Windows 8GB RAM was more than enough. As memory profiler shows, the problem is somewhere in kl_reg method of LinearSVDO class. Here is code and profiling results:

import torch
import numpy as np

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from logger import Logger
from torch.nn import Parameter
from torchvision import datasets, transforms

from tqdm import trange, tqdm
from memory_profiler import profile

# Load a dataset
def get_mnist(batch_size):
    trsnform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
        transform=trsnform), batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
        transform=trsnform), batch_size=batch_size, shuffle=True)

    return train_loader, test_loader

class LinearSVDO(nn.Module):
    # Хардкодим параметры здесь для читаемости
    shift = 1e-8
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    log_alpha_lower = -10.0
    log_alpha_upper = 10.0


    def __init__(self, in_features, out_features, threshold, bias=True):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold

        self.W = Parameter(torch.Tensor(out_features, in_features))
        ###########################################################
        ########         You Code should be here         ##########
        # Create a Parameter to store log sigma
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        ###########################################################
        self.bias = Parameter(torch.Tensor(1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def forward(self, x: torch.Tensor):
        ###########################################################
        ########         You Code should be here         ##########
        if self.training:
            lrt_mean = F.linear(x, self.W) + self.bias  # Compute activation's mean e.g x.dot(W) + b
            temp = F.linear(x.pow(2), torch.exp(self.log_sigma * 2.0))
            lrt_std = torch.sqrt(temp + self.shift)  # Compute activation's var e.g sqrt((x*x).dot(sigma * sigma) + 1e-8)
            eps = torch.normal(torch.FloatTensor([0.]).expand(lrt_std.size()),
                               torch.FloatTensor([1.]).expand(lrt_std.size()))  # sample random noise
            res = lrt_mean + lrt_std * eps
            return res

        ########         If not training        ##########
        self.log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(self.shift + torch.abs(self.W))  # Evale log alpha as a function(log_sigma, W)
        self.log_alpha = torch.clamp(self.log_alpha, self.log_alpha_lower, self.log_alpha_upper)# Clip log alpha to be in [-10, 10] for numerical stability
        W = self.W * (self.log_alpha < 3.0).type(torch.FloatTensor)  # Prune out redundant wights e.g. W * mask(log_alpha < 3)
        return F.linear(x, W) + self.bias
        ###########################################################
    
    @profile
    def kl_reg(self):
        ###########################################################
        ########         You Code should be here         ##########
        ########  Eval Approximation of KL Divergence    ##########
        # use torch.log1p for numerical stability
        log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(torch.abs(self.W) + self.shift) # Evale log alpha as a function(log_sigma, W)
        log_alpha = torch.clamp(log_alpha, self.log_alpha_lower, self.log_alpha_upper) # Clip log alpha to be in [-10, 10] for numerical suability
        KL1 = self.k1 * torch.sigmoid(self.k2 + self.k3 * log_alpha) 
        KL2 = - 0.5 * torch.log1p(torch.exp(-log_alpha))
        KL = KL1 + KL2
        return -torch.sum(KL)
        ########  Return a KL divergence, a Tensor 1x1   ##########
        ###########################################################

# Define a simple 2 layer Network
class Net(nn.Module):
    def __init__(self, threshold):
        super(Net, self).__init__()
        self.fc1 = LinearSVDO(28*28, 300, threshold)
        self.fc2 = LinearSVDO(300,  10, threshold)
        self.threshold = threshold

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x

# Define a new Loss Function -- ELBO 
class ELBO(nn.Module):
    def __init__(self, net, train_size):
        super(ELBO, self).__init__()
        self.train_size = train_size
        self.net = net
    
    def forward(self, input, target, kl_weight=1.0):
        assert not target.requires_grad
        ###
        kl = torch.Tensor([0.0])
        ###
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()
        ###########################################################
        ########         You Code should be here         ##########    
        # Compute Stochastic Gradient Variational Lower Bound
        # It is a sum of cross-entropy (Data term) and KL-divergence (Regularizer)
        # Do not forget to scale up Data term to N/M,
        # where N is a size of the dataset and M is a size of minibatch
        
        # Делить на размер батча не нужно, в функции F.cross_entropy и так по умолчанию берется среднее
        ELBO = F.cross_entropy(input, target) * self.train_size + kl_weight * kl
        return ELBO # a Tensor 1x1 
        ###########################################################

def run():
    model = Net(threshold=3)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,60,70,80], gamma=0.2)

    fmt = {'tr_los': '3.1e', 'te_loss': '3.1e', 'sp_0': '.3f', 'sp_1': '.3f', 'lr': '3.1e', 'kl': '.2f'}
    logger = Logger('sparse_vd', fmt=fmt)

    train_loader, test_loader = get_mnist(batch_size=100)
    elbo = ELBO(model, len(train_loader.dataset))
    kl_weight = 0.02
    epochs = 5

    for epoch in range(1, epochs + 1):
        scheduler.step()
        model.train()
        train_loss, train_acc = 0, 0 
        kl_weight = min(kl_weight+0.02, 1)
        logger.add_scalar(epoch, 'kl', kl_weight)
        logger.add_scalar(epoch, 'lr', scheduler.get_lr()[0])
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(-1, 28*28)
            optimizer.zero_grad()

            output = model(data)
            pred = output.data.max(1)[1] 
            loss = elbo(output, target, kl_weight)
            loss.backward()
            optimizer.step()

            train_loss += loss 
            train_acc += np.sum(pred.numpy() == target.data.numpy())

        logger.add_scalar(epoch, 'tr_los', train_loss / len(train_loader.dataset))
        logger.add_scalar(epoch, 'tr_acc', train_acc / len(train_loader.dataset) * 100)


        model.eval()
        test_loss, test_acc = 0, 0
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.view(-1, 28*28)
            output = model(data)
            test_loss += float(elbo(output, target, kl_weight))
            pred = output.data.max(1)[1] 
            test_acc += np.sum(pred.numpy() == target.data.numpy())

        logger.add_scalar(epoch, 'te_loss', test_loss / len(test_loader.dataset))
        logger.add_scalar(epoch, 'te_acc', test_acc / len(test_loader.dataset) * 100)

        for i, c in enumerate(model.children()):
            if hasattr(c, 'kl_reg'):
                logger.add_scalar(epoch, 'sp_%s' % i, (c.log_alpha.data.numpy() > model.threshold).mean())

        logger.iter_info()
Line #    Mem usage    Increment   Line Contents
================================================
   129    185.0 MiB    185.0 MiB       @profile
   130                                 def kl_reg(self):
   131                                     ###########################################################
   132                                     ########         You Code should be here         ##########
   133                                     ########  Eval Approximation of KL Divergence    ##########
   134                                     # use torch.log1p for numerical stability
   135    189.1 MiB      4.1 MiB           log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(torch.abs(self.W) + self.shift) # Evale log alpha as a function(log_sigma, W)
   136    190.2 MiB      1.0 MiB           log_alpha = torch.clamp(log_alpha, self.log_alpha_lower, self.log_alpha_upper) # Clip log alpha to be in [-10, 10] for numerical suability
   137    194.7 MiB      4.5 MiB           KL1 = self.k1 * torch.sigmoid(self.k2 + self.k3 * log_alpha) 
   138    197.5 MiB      2.8 MiB           KL2 = - 0.5 * torch.log1p(torch.exp(-log_alpha))
   139    199.2 MiB      1.7 MiB           KL = KL1 + KL2
   140    199.2 MiB      0.0 MiB           return -torch.sum(KL)

One reason for the memory growth is that you are storing the whole computation graph in train_loss by summing loss to it.
Since you only need train_loss to print the current loss, you should detach the computation graph from it and just store the loss value:

train_loss += loss.item()

It should cause a memory growth on other platforms as well, so there might be another issue.
Let me know, if that solves the problem.

1 Like

Thanks! That solved the problem.
Still not clear, why this particular leak is platform-specific.