Hi
I'm training a network with two fully connected layers and sparse variational dropout on the MNIST data. I've run into a terrible memory leak: after 100 epochs, more than 200 GB of RAM is in use. The problem is specific to Ubuntu; on Mac and Windows, 8 GB of RAM was more than enough. According to the memory profiler, the problem is somewhere in the kl_reg method of the LinearSVDO class. Here are the code and the profiling results:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from logger import Logger
from torch.nn import Parameter
from torchvision import datasets, transforms
from tqdm import trange, tqdm
from memory_profiler import profile
# Load a dataset
def get_mnist(batch_size):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
class LinearSVDO(nn.Module):
    # Hard-code the parameters here for readability
    shift = 1e-8
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    log_alpha_lower = -10.0
    log_alpha_upper = 10.0

    def __init__(self, in_features, out_features, threshold, bias=True):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.W = Parameter(torch.Tensor(out_features, in_features))
        ###########################################################
        ########        Your code should be here          ##########
        # Create a Parameter to store log sigma
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        ###########################################################
        self.bias = Parameter(torch.Tensor(1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)
    def forward(self, x: torch.Tensor):
        ###########################################################
        ########        Your code should be here          ##########
        if self.training:
            lrt_mean = F.linear(x, self.W) + self.bias  # activation mean, i.e. x.dot(W) + b
            temp = F.linear(x.pow(2), torch.exp(self.log_sigma * 2.0))
            lrt_std = torch.sqrt(temp + self.shift)  # activation std, i.e. sqrt((x*x).dot(sigma*sigma) + 1e-8)
            eps = torch.normal(torch.FloatTensor([0.]).expand(lrt_std.size()),
                               torch.FloatTensor([1.]).expand(lrt_std.size()))  # sample standard normal noise
            res = lrt_mean + lrt_std * eps
            return res
        ########          If not training                 ##########
        # Evaluate log alpha as a function of log_sigma and W
        self.log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(self.shift + torch.abs(self.W))
        # Clip log alpha to [-10, 10] for numerical stability
        self.log_alpha = torch.clamp(self.log_alpha, self.log_alpha_lower, self.log_alpha_upper)
        # Prune out redundant weights, i.e. W * mask(log_alpha < 3)
        W = self.W * (self.log_alpha < 3.0).type(torch.FloatTensor)
        return F.linear(x, W) + self.bias
        ###########################################################
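    # For reference, my understanding of the local reparameterization trick used above
    # (Kingma et al., 2015): the pre-activations are Gaussian with
    #   mean = x @ W.T + b   and   var = (x * x) @ (sigma * sigma).T,
    # so a sample is drawn as mean + sqrt(var) * eps with eps ~ N(0, 1).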
    @profile
    def kl_reg(self):
        ###########################################################
        ########        Your code should be here          ##########
        ########   Eval Approximation of KL Divergence    ##########
        # use torch.log1p for numerical stability
        # Evaluate log alpha as a function of log_sigma and W
        log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(torch.abs(self.W) + self.shift)
        # Clip log alpha to [-10, 10] for numerical stability
        log_alpha = torch.clamp(log_alpha, self.log_alpha_lower, self.log_alpha_upper)
        KL1 = self.k1 * torch.sigmoid(self.k2 + self.k3 * log_alpha)
        KL2 = -0.5 * torch.log1p(torch.exp(-log_alpha))
        KL = KL1 + KL2
        # Return the KL divergence as a 1x1 Tensor
        return -torch.sum(KL)
        ###########################################################
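    # For reference, my reading of the approximation from Molchanov et al. (2017),
    # "Variational Dropout Sparsifies Deep Neural Networks":
    #   -KL(q || p) ≈ k1 * sigmoid(k2 + k3 * log_alpha) - 0.5 * log(1 + 1/alpha) - k1,
    # where log(1 + 1/alpha) = log1p(exp(-log_alpha)); the constant k1 term is dropped
    # above since it does not affect the gradients.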
# Define a simple 2-layer network
class Net(nn.Module):
    def __init__(self, threshold):
        super(Net, self).__init__()
        self.fc1 = LinearSVDO(28 * 28, 300, threshold)
        self.fc2 = LinearSVDO(300, 10, threshold)
        self.threshold = threshold

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
# Define a new loss function -- ELBO
class ELBO(nn.Module):
    def __init__(self, net, train_size):
        super(ELBO, self).__init__()
        self.train_size = train_size
        self.net = net

    def forward(self, input, target, kl_weight=1.0):
        assert not target.requires_grad
        ###
        kl = torch.Tensor([0.0])
        ###
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()
        ###########################################################
        ########        Your code should be here          ##########
        # Compute the Stochastic Gradient Variational Lower Bound:
        # the sum of the cross-entropy (data term) and the KL divergence (regularizer).
        # Do not forget to scale the data term up by N/M,
        # where N is the dataset size and M is the minibatch size.
        # No need to divide by the batch size: F.cross_entropy already averages over the batch by default.
        ELBO = F.cross_entropy(input, target) * self.train_size + kl_weight * kl
        return ELBO  # a 1x1 Tensor
        ###########################################################
def run():
    model = Net(threshold=3)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80], gamma=0.2)
    fmt = {'tr_los': '3.1e', 'te_loss': '3.1e', 'sp_0': '.3f', 'sp_1': '.3f', 'lr': '3.1e', 'kl': '.2f'}
    logger = Logger('sparse_vd', fmt=fmt)
    train_loader, test_loader = get_mnist(batch_size=100)
    elbo = ELBO(model, len(train_loader.dataset))
    kl_weight = 0.02
    epochs = 5
    for epoch in range(1, epochs + 1):
        scheduler.step()
        model.train()
        train_loss, train_acc = 0, 0
        kl_weight = min(kl_weight + 0.02, 1)
        logger.add_scalar(epoch, 'kl', kl_weight)
        logger.add_scalar(epoch, 'lr', scheduler.get_lr()[0])
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(-1, 28 * 28)
            optimizer.zero_grad()
            output = model(data)
            pred = output.data.max(1)[1]
            loss = elbo(output, target, kl_weight)
            loss.backward()
            optimizer.step()
            train_loss += loss
            train_acc += np.sum(pred.numpy() == target.data.numpy())
        logger.add_scalar(epoch, 'tr_los', train_loss / len(train_loader.dataset))
        logger.add_scalar(epoch, 'tr_acc', train_acc / len(train_loader.dataset) * 100)
        model.eval()
        test_loss, test_acc = 0, 0
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.view(-1, 28 * 28)
            output = model(data)
            test_loss += float(elbo(output, target, kl_weight))
            pred = output.data.max(1)[1]
            test_acc += np.sum(pred.numpy() == target.data.numpy())
        logger.add_scalar(epoch, 'te_loss', test_loss / len(test_loader.dataset))
        logger.add_scalar(epoch, 'te_acc', test_acc / len(test_loader.dataset) * 100)
        for i, c in enumerate(model.children()):
            if hasattr(c, 'kl_reg'):
                logger.add_scalar(epoch, 'sp_%s' % i, (c.log_alpha.data.numpy() > model.threshold).mean())
        logger.iter_info()
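The per-line numbers below were produced with memory_profiler's @profile decorator on kl_reg (imported at the top of the script). Assuming the script is saved as train_svdo.py (the file name is just a placeholder), a report like this can be obtained with something like

    python -m memory_profiler train_svdo.py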
Line #    Mem usage    Increment   Line Contents
================================================
   129    185.0 MiB    185.0 MiB       @profile
   130                                 def kl_reg(self):
   131                                     ###########################################################
   132                                     ######## You Code should be here ##########
   133                                     ######## Eval Approximation of KL Divergence ##########
   134                                     # use torch.log1p for numerical stability
   135    189.1 MiB      4.1 MiB             log_alpha = 2.0 * self.log_sigma - 2.0 * torch.log(torch.abs(self.W) + self.shift) # Evale log alpha as a function(log_sigma, W)
   136    190.2 MiB      1.0 MiB             log_alpha = torch.clamp(log_alpha, self.log_alpha_lower, self.log_alpha_upper) # Clip log alpha to be in [-10, 10] for numerical suability
   137    194.7 MiB      4.5 MiB             KL1 = self.k1 * torch.sigmoid(self.k2 + self.k3 * log_alpha)
   138    197.5 MiB      2.8 MiB             KL2 = - 0.5 * torch.log1p(torch.exp(-log_alpha))
   139    199.2 MiB      1.7 MiB             KL = KL1 + KL2
   140    199.2 MiB      0.0 MiB             return -torch.sum(KL)
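In case it helps narrow this down, here is a small diagnostic I could add at the end of each epoch in run() (not part of the script above; the function name is mine) to check whether the number of live tensors keeps growing between epochs, which would suggest that old computation graphs are being kept alive somewhere:

import gc  # torch is already imported at the top of the script

def count_live_tensors():
    # Count the torch.Tensor objects currently tracked by the garbage collector.
    # If this number keeps growing from epoch to epoch, some tensors (and the
    # computation graphs that reference them) are being retained between iterations.
    n = 0
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            n += 1
    return n

# usage inside run(), e.g. at the end of each epoch:
# print('live tensors:', count_live_tensors())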