I am trying to train a vgg model where each 30 iterations i use a prunning algorithm to prune one of the feature maps (filters) and we continue training. Problem is that, somehow during prunning the gpu memory goes higher and higher and after a while cuda goes out of memory. I am wondering if anyone can help me why the memory usage is going up or a way to figure out what are the tensors on the gpu. Thanks.
Here is the main training loop where i prune a feature map every 30 mini batch steps:
def train(self):
self.keep_ratio = 1
self.epoch = 1
self.already_prunned = False
self.optimiser = optim.SGD(self.net.parameters(), lr=self.initial_lr, momentum=self.momentum, weight_decay=0)
self.writer = SummaryWriter()
self.train_generator = iter(self.train_loader)
self.val_generator = iter(self.val_loader)
t_start = time.time()
running_loss = 0
running_accuracy = 0
for self.itr in range(self.max_training_steps):
print('\r{}'.format(self.itr), end='', flush=True)
self.net.train()
self.optimiser.zero_grad()
if self.itr%30 == 0 :
self.save_model()
self.Prunning = PrunningModel(self.net, self.train_loader, weight_decay_rate = self.weight_decay_rate, epoch= False, use_cuda = self.use_cuda)
self.net = self.Prunning.prune(1)
self.optimiser = optim.SGD(self.net.parameters(), lr=self.initial_lr, momentum=self.momentum, weight_decay=0)
self.optimiser.zero_grad()
self.count_params_after_prunning = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
self.keep_ratio = self.count_params_after_prunning/self.count_params
print ("keep ratio after prunning is {}".format(self.keep_ratio))
self.already_prunned = True
try:
img, label = next(self.train_generator)
except StopIteration:
print("epoch {} training finished".format(self.epoch))
self.train_generator = iter(self.train_loader)
img, label = next(self.train_generator)
self.epoch = self.epoch +1
self.already_prunned = False
if isinstance(img, list):
img = img[0]
label = torch.squeeze(label)
img = img.to(self.device)
label = label.to(self.device)
self.img = img
self.label = label
output = self.net(img)
self.output = output
self.loss = F.nll_loss(output, label) +self.weight_decay_rate * torch.sum(torch.stack([torch.sum(v**2) for v in self.net.parameters()]))
_, preds = torch.max(output, 1)
accuracy = torch.sum(preds == label.data)
self.loss.backward()
self.optimiser.step()
#self.lr_scheduler.step()
running_loss += self.loss.item()
running_accuracy += accuracy.item()
and following is the code for my prunning methodology. For any feature map we calculate gradient with respect to its activation times the value of activation and take the absolute value and remove the ones that are lowest.
class FilterPrunner:
def __init__(self, model, use_cuda ,lamda = 1e-3):
self.model = model
self.model_name = str(self.model.__class__.__name__)
self.use_cuda = use_cuda
self.reset()
self.lamda = lamda
def reset(self):
self.filter_ranks = {}
def forward(self, x):
self.activations = []
self.gradients = []
self.grad_index = 0
self.activation_to_layer = {}
activation_index = 0
for layer, (name, module) in enumerate(self.model.features._modules.items()):
x = module(x)
if isinstance(module, nn.modules.conv.Conv2d):
x.register_hook(self.compute_rank)
self.activations.append(x)
self.activation_to_layer[activation_index] = layer
activation_index += 1
x = self.model.avgpool(x)
return F.log_softmax(self.model.classifier(x.view(x.size(0), -1)), dim=1)
def compute_rank(self, grad):
activation_index = len(self.activations) - self.grad_index - 1
activation = self.activations[activation_index]
#print(activation.shape)
taylor = activation * grad
# Get the average value for every filter,
# accross all the other dimensions
taylor = taylor.mean(dim=(0, 2, 3)).data.cpu()
if activation_index not in self.filter_ranks:
self.filter_ranks[activation_index] = \
torch.FloatTensor(activation.size(1)).zero_()
self.filter_ranks[activation_index] += taylor
self.grad_index += 1
def lowest_ranking_filters(self, num):
data = []
for i in sorted(self.filter_ranks.keys()):
for j in range(self.filter_ranks[i].size(0)):
data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j]))
return nsmallest(num, data, itemgetter(2))
def normalize_ranks_per_layer(self):
for i in self.filter_ranks:
v = torch.abs(self.filter_ranks[i]).cpu()
v = v / np.sqrt(torch.sum(v * v))
if self.flops_layer:
v = v
self.filter_ranks[i] = v
def get_prunning_plan(self, num_filters_to_prune):
filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune)
# After each of the k filters are prunned,
# the filter index of the next filters change since the model is smaller.
filters_to_prune_per_layer = {}
for (l, f, _) in filters_to_prune:
if l not in filters_to_prune_per_layer:
filters_to_prune_per_layer[l] = []
filters_to_prune_per_layer[l].append(f)
for l in filters_to_prune_per_layer:
filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l])
for i in range(len(filters_to_prune_per_layer[l])):
filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i
filters_to_prune = []
for l in filters_to_prune_per_layer:
for i in filters_to_prune_per_layer[l]:
filters_to_prune.append((l, i))
return filters_to_prune
class PrunningModel:
def __init__(self, model, train_loader, weight_decay_rate, epoch = False, use_cuda = True):
self.model = model
self.train_loader = train_loader
self.epoch = epoch #if use all the epoch to find the rank of filter or just one batch.
self.use_cuda = use_cuda
self.prunner = FilterPrunner(self.model, self.use_cuda)
self.weight_decay_rate = weight_decay_rate
def get_candidates_to_prune(self, num_filters_to_prune):
self.prunner.reset()
if self.epoch:
self.rank_filters_epoch()
else:
img, label = next(iter(self.train_loader))
if isinstance(img, list):
img = img[0]
if self.use_cuda:
img = img.cuda()
label = label.cuda()
self.rank_filters_batch(img, label)
self.prunner.normalize_ranks_per_layer()
return self.prunner.get_prunning_plan(num_filters_to_prune)
def total_num_filters(self): #calculates the total number of filters in the convolutional layers
filters = 0
for name, module in self.model.features._modules.items():
if isinstance(module, nn.modules.conv.Conv2d):
filters = filters + module.out_channels
return filters
def rank_filters_batch(self, batch, label):
torch.cuda.empty_cache()
if self.use_cuda:
batch = batch.cuda()
label = label.cuda()
self.model.zero_grad()
input = Variable(batch)
output = self.prunner.forward(input)
self.loss = F.nll_loss(output, label) +self.weight_decay_rate * torch.sum(torch.stack([torch.sum(v**2) for v in self.model.parameters()]))
self.loss.backward(retain_graph = False)
def rank_filters_epoch(self):
for i, (batch, label) in enumerate(self.train_loader):
img = batch[0]
self.rank_filters_batch(img, label)
def prune(self, num_filters_to_prune):
#Make sure all the layers are trainable
for param in self.model.features.parameters():
param.requires_grad = True
number_of_filters = self.total_num_filters()
print('Initial number of filters is {}'.format(number_of_filters))
print("Ranking filters.. ")
prune_targets = self.get_candidates_to_prune(num_filters_to_prune)
layers_prunned = {}
for layer_index, filter_index in prune_targets:
if layer_index not in layers_prunned:
layers_prunned[layer_index] = 0
layers_prunned[layer_index] = layers_prunned[layer_index] + 1
print("Layers that will be prunned", layers_prunned)
print("Prunning filters.. ")
model = self.model.cpu()
for layer_index, filter_index in prune_targets:
model = prune_conv_layer(model, layer_index, filter_index, use_cuda=self.use_cuda)
self.model = model
if self.use_cuda:
self.model = self.model.cuda()
new_number_of_filters = self.total_num_filters()
message = str(100 - 100*float(new_number_of_filters) / number_of_filters) + "%"
print("Percentage of prunned Filters", str(message))
print('Number of filters after pruning {}'.format(new_number_of_filters))
#torch.save(model.state_dict(), "model_prunned")
return self.model