How can I check what is using GPU memory?

I am training a VGG model, and every 30 iterations I use a pruning algorithm to remove one of the feature maps (filters) before continuing training. The problem is that GPU memory usage keeps growing with each pruning step, and after a while CUDA runs out of memory. Can anyone help me understand why memory usage keeps increasing, or suggest a way to find out which tensors are on the GPU? Thanks.

Here is the main training loop, where I prune one feature map every 30 mini-batch steps:

def train(self):
    """Train the model, pruning one filter every 30 iterations.

    Runs ``self.max_training_steps`` iterations over ``self.train_loader``.
    When the loader is exhausted, the iterator is recreated and ``self.epoch``
    is incremented. On every iteration divisible by 30 (including iteration 0),
    a new ``PrunningModel`` is built and ``prune(1)`` is called. The SGD
    optimiser is then recreated and ``self.keep_ratio`` is recomputed.

    NOTE(review): this listing is incomplete. Several expressions are missing
    (see the inline notes), and the loop body appears to be cut off after the
    loss/accuracy bookkeeping. The backward pass and the optimiser step are
    not shown here.
    """
    self.keep_ratio = 1
    self.epoch = 1
    self.already_prunned = False
    # NOTE(review): the params argument of optim.SGD is missing (presumably
    # self.model.parameters()); as written this line is a syntax error.
    self.optimiser = optim.SGD(, lr=self.initial_lr, momentum=self.momentum, weight_decay=0)
    self.writer = SummaryWriter()

    self.train_generator = iter(self.train_loader)
    self.val_generator = iter(self.val_loader)
    t_start = time.time()
    running_loss = 0
    running_accuracy = 0

    for self.itr in range(self.max_training_steps):
        print('\r{}'.format(self.itr), end='', flush=True)

        # Every 30 iterations (including iteration 0): prune one filter,
        # then rebuild the optimiser for the changed parameter set.
        if self.itr%30 == 0 :
            # NOTE(review): the first (model) argument is missing, presumably
            # self.model. The previous PrunningModel, and everything it
            # references, stays alive until this assignment replaces it.
            self.Prunning = PrunningModel(, self.train_loader, weight_decay_rate = self.weight_decay_rate, epoch= False, use_cuda = self.use_cuda)
            # NOTE(review): the assignment target is missing, presumably self.model.
   = self.Prunning.prune(1)
            # The old optimiser, including its momentum buffers, is only freed
            # once this rebinding drops the last reference to it.
            # NOTE(review): params argument missing here as well.
            self.optimiser = optim.SGD(, lr=self.initial_lr, momentum=self.momentum, weight_decay=0)
            # NOTE(review): the iterable after `for p in` is missing
            # (presumably self.model.parameters()).
            self.count_params_after_prunning = sum(p.numel() for p in if p.requires_grad)
            # Fraction of the original trainable parameters still present;
            # self.count_params is assumed to be set elsewhere (TODO confirm).
            self.keep_ratio = self.count_params_after_prunning/self.count_params
            print ("keep ratio after prunning is {}".format(self.keep_ratio))
            self.already_prunned = True

            # NOTE(review): a `try:` line appears to be missing before the next
            # statement; the `except StopIteration:` below has no matching `try`.
            img, label = next(self.train_generator)
        except StopIteration:
            # Loader exhausted: start a new epoch with a fresh iterator.
            print("epoch {} training finished".format(self.epoch))  

            self.train_generator = iter(self.train_loader)
            img, label = next(self.train_generator)
            self.epoch = self.epoch +1
            self.already_prunned = False

        if isinstance(img, list):
            img = img[0]
        label = torch.squeeze(label)
        # NOTE(review): the right-hand sides of the next two assignments are
        # missing (presumably moving img/label to the GPU).
        img =

        label =

        # NOTE(review): storing the batch, output and loss on self keeps the
        # latest batch and its autograd graph referenced until the next
        # iteration overwrites them. Worth dropping if GPU memory is tight.
        self.img = img
        self.label = label
        # NOTE(review): right-hand side missing (presumably self.model(img)).
        output =
        self.output = output

        # Loss = NLL + weight_decay_rate * sum of squared parameter values.
        # NOTE(review): the iterable after `for v in` is missing (presumably
        # self.model.parameters()).
        self.loss = F.nll_loss(output, label) +self.weight_decay_rate * torch.sum(torch.stack([torch.sum(v**2) for v in]))
        _, preds = torch.max(output, 1)
        # NOTE(review): this line is truncated; the right-hand operand of `==`
        # and the closing parenthesis are missing.
        accuracy = torch.sum(preds ==

        # .item() yields Python numbers, so the running totals do not keep
        # references to the loss/accuracy tensors.
        running_loss += self.loss.item()
        running_accuracy += accuracy.item() 

And the following is the code for my pruning method. For each feature map, we compute the gradient of the loss with respect to its activation, multiply it by the activation value, take the absolute value, and remove the feature maps with the lowest scores.

class FilterPrunner:
    """Rank convolutional filters with a first-order Taylor criterion.

    For every ``Conv2d`` output in ``model.features`` that takes part in
    autograd, a filter's rank is the mean of ``activation * gradient`` over all
    dimensions except the channel dimension (dim 1). Ranks accumulate over
    backward passes in ``filter_ranks``. ``normalize_ranks_per_layer`` turns
    them into absolute, per-layer L2-normalised scores.

    The model must expose ``features`` (iterated module by module),
    ``avgpool`` and ``classifier``, as used by ``forward``.
    """

    def __init__(self, model, use_cuda, lamda=1e-3):
        self.model = model
        self.model_name = str(self.model.__class__.__name__)
        self.use_cuda = use_cuda
        self.lamda = lamda
        # Start from a clean state so that compute_rank and
        # lowest_ranking_filters work even if reset()/forward() were not
        # called first.
        self.reset()
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}

    def reset(self):
        """Discard all accumulated filter ranks."""
        self.filter_ranks = {}

    def forward(self, x):
        """Run the model on *x* and hook every conv output for ranking.

        Args:
            x: input batch for ``model.features``.

        Returns:
            Log-probabilities from ``model.classifier``.

        Each conv output that requires grad gets ``compute_rank`` registered
        as a gradient hook. A detached copy is kept so the hook can read it.
        Outputs that cannot receive a gradient are skipped, because they would
        break the reverse-order pairing in ``compute_rank``.
        """
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}

        activation_index = 0
        for layer, module in enumerate(self.model.features._modules.values()):
            x = module(x)
            if isinstance(module, nn.Conv2d) and x.requires_grad:
                x.register_hook(self.compute_rank)
                # Detached: the hook only needs the values, not the graph.
                self.activations.append(x.detach())
                self.activation_to_layer[activation_index] = layer
                activation_index += 1
        x = self.model.avgpool(x)
        return F.log_softmax(self.model.classifier(x.view(x.size(0), -1)), dim=1)

    def compute_rank(self, grad):
        """Gradient hook: accumulate the Taylor rank of one conv activation.

        Hooks fire in reverse forward order. The n-th call therefore belongs
        to the n-th recorded activation counted from the end.
        """
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        # Average activation * gradient over batch and spatial dims,
        # leaving one value per filter (channel).
        taylor = (activation * grad).mean(dim=(0, 2, 3)).detach().cpu()

        if activation_index not in self.filter_ranks:
            self.filter_ranks[activation_index] = torch.zeros_like(taylor)
        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1

        # Every recorded activation has been consumed. Release the (possibly
        # GPU) tensors now instead of keeping them until the next forward().
        if self.grad_index == len(self.activations):
            self.activations = []

    def lowest_ranking_filters(self, num):
        """Return the *num* lowest-ranked filters as ``(layer, filter, rank)``."""
        data = []
        for i in sorted(self.filter_ranks):
            layer = self.activation_to_layer[i]
            for j, rank in enumerate(self.filter_ranks[i].tolist()):
                data.append((layer, j, rank))
        return nsmallest(num, data, itemgetter(2))

    def normalize_ranks_per_layer(self):
        """Replace each layer's ranks with their absolute values, L2-normalised.

        NOTE(review): a per-layer FLOPs weighting could be applied here; none
        is currently defined.
        """
        for i, ranks in self.filter_ranks.items():
            v = torch.abs(ranks).cpu()
            norm = v.norm()
            # An all-zero layer would otherwise turn into NaN (0 / 0).
            if norm > 0:
                v = v / norm
            self.filter_ranks[i] = v

    def get_prunning_plan(self, num_filters_to_prune):
        """Return the ``(layer_index, filter_index)`` pairs to prune, in order.

        Filters are removed one at a time. After the k-th filter of a layer is
        removed, every later filter index in that layer shifts down by one.
        The indices returned here are already adjusted for that shift.
        """
        filters_to_prune_per_layer = {}
        for layer_index, filter_index, _ in self.lowest_ranking_filters(num_filters_to_prune):
            filters_to_prune_per_layer.setdefault(layer_index, []).append(filter_index)

        filters_to_prune = []
        for layer_index, filter_indices in filters_to_prune_per_layer.items():
            for shift, filter_index in enumerate(sorted(filter_indices)):
                filters_to_prune.append((layer_index, filter_index - shift))
        return filters_to_prune

class PrunningModel:
    """Rank the conv filters of a model and prune the lowest-ranked ones.

    Args:
        model: network whose ``features`` hold the conv layers to prune.
            Ranking goes through a ``FilterPrunner``.
        train_loader: iterable of ``(img, label)`` batches used for ranking.
            ``img`` may also be a list whose first element is the image batch.
        weight_decay_rate: coefficient of the squared-parameter penalty added
            to the NLL loss during ranking.
        epoch: if True, rank over the whole ``train_loader``; otherwise rank
            on a single batch.
        use_cuda: move ranking batches and the pruned model to the GPU.
    """

    def __init__(self, model, train_loader, weight_decay_rate, epoch=False, use_cuda=True):
        self.model = model
        self.train_loader = train_loader
        self.epoch = epoch  # True: rank over the full loader; False: one batch.
        self.use_cuda = use_cuda
        self.prunner = FilterPrunner(self.model, self.use_cuda)
        self.weight_decay_rate = weight_decay_rate

    def get_candidates_to_prune(self, num_filters_to_prune):
        """Rank the filters from scratch and return the pruning plan.

        Returns:
            The ``(layer_index, filter_index)`` list produced by
            ``FilterPrunner.get_prunning_plan``.
        """
        # Ranks from an earlier call must not leak into this one.
        self.prunner.reset()
        if self.epoch:
            self.rank_filters_epoch()
        else:
            img, label = next(iter(self.train_loader))
            self.rank_filters_batch(img, label)
        # Filters are compared by score magnitude, normalised per layer.
        self.prunner.normalize_ranks_per_layer()
        return self.prunner.get_prunning_plan(num_filters_to_prune)

    def total_num_filters(self):
        """Return the number of output channels summed over every Conv2d in ``features``."""
        return sum(
            module.out_channels
            for module in self.model.features._modules.values()
            if isinstance(module, nn.modules.conv.Conv2d)
        )

    def rank_filters_batch(self, batch, label):
        """Run forward/backward on one batch so the prunner's hooks accumulate ranks."""
        if isinstance(batch, list):
            batch = batch[0]
        if self.use_cuda:
            batch = batch.cuda()
            label = label.cuda()

        output = self.prunner.forward(batch)
        penalty = torch.sum(torch.stack([torch.sum(v ** 2) for v in self.model.parameters()]))
        loss = F.nll_loss(output, label) + self.weight_decay_rate * penalty
        loss.backward()
        # Keep only a detached scalar; holding the loss tensor itself would
        # keep the root of its autograd graph referenced.
        self.loss = loss.detach()
        # Only the activation gradients (seen by the hooks) are needed for
        # ranking. Drop the parameter gradients this backward produced so they
        # neither occupy GPU memory nor leak into the next training step.
        for param in self.model.parameters():
            param.grad = None

    def rank_filters_epoch(self):
        """Accumulate filter ranks over every batch of ``train_loader``."""
        for batch, label in self.train_loader:
            self.rank_filters_batch(batch, label)

    def prune(self, num_filters_to_prune):
        """Prune ``num_filters_to_prune`` filters and return the pruned model."""
        # Make sure all the layers are trainable (the ranking hooks need grads).
        for param in self.model.features.parameters():
            param.requires_grad = True

        number_of_filters = self.total_num_filters()
        print('Initial number of filters is {}'.format(number_of_filters))
        print("Ranking filters.. ")
        prune_targets = self.get_candidates_to_prune(num_filters_to_prune)
        layers_prunned = {}
        for layer_index, _ in prune_targets:
            layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + 1

        print("Layers that will be prunned", layers_prunned)
        print("Prunning filters.. ")
        model = self.model.cpu()
        for layer_index, filter_index in prune_targets:
            # prune_conv_layer is defined elsewhere in the project.
            model = prune_conv_layer(model, layer_index, filter_index, use_cuda=self.use_cuda)
        self.model = model
        # Keep the prunner pointed at the model that now exists.
        self.prunner.model = self.model
        if self.use_cuda:
            self.model = self.model.cuda()
            # Hand cached blocks sized for the old layer shapes back to the
            # driver. Otherwise they still count as used (e.g. in nvidia-smi)
            # and may not fit the new, smaller shapes.
            torch.cuda.empty_cache()

        new_number_of_filters = self.total_num_filters()
        message = str(100 - 100*float(new_number_of_filters) / number_of_filters) + "%"
        print("Percentage of prunned Filters", str(message))
        print('Number of filters after pruning {}'.format(new_number_of_filters))
        return self.model