Memory usage almost doubles while training

I have a problem: memory usage almost doubles during training.
Here is my code; please suggest a solution to this memory growth.
The call order is ssl_train -> _train_epoches -> _train_batch, and memory seems to grow every time _train_epoches finishes. Thank you. @ptrblck_de, could you help me?

def _train_batch(self, input_variable, input_length, target_variable, model):
    # One optimization step: forward pass, loss, backward pass, parameter update.
    logits, attn = model(input_variable, input_length)
    loss = self.criterion(logits, target_variable)
    self.optimizer.optimizer.zero_grad()  # self.optimizer wraps the underlying torch optimizer
    loss.backward()
    self.optimizer.step()
    return loss



def _train_epoches(self, data, model, n_epochs, dev_data=None, test_data=None):
    labeled_dataset = torchtext.data.Dataset(data, fields=[('text', self.TEXT_field), ('label', self.LABEL_field)])
    label_batch_iter = torchtext.data.BucketIterator(dataset=labeled_dataset, batch_size=128,
                                                     sort_key=lambda x:len(x.text), sort_within_batch=True,
                                                     device=self.device, repeat=False, shuffle=True)
    log = self.logger
    
    early_stopping = EarlyStopping(patience=3, verbose=True)
    best_accuracy = 0
    
    for epoch in range(0, n_epochs):
        model.train()
        loss_total = 0
        step = 0
        for batch in label_batch_iter:
            input_variables, input_lengths = batch.text
            target_variables = batch.label
            loss = self._train_batch(input_variables, input_lengths.tolist(), target_variables, model)
            loss_total += loss.item()
            step += 1
            del loss, batch
            
        epoch_loss_avg = loss_total / step
        log_msg = "Finished epoch %d: SSL Train %s: %.4f" % (epoch, 'Cross_Entropy', epoch_loss_avg)
        with torch.no_grad():
            if dev_data is not None:
                model.eval()
                dev_loss, dev_acc = self.evaluator.evaluate(model, dev_data)
                self.dev_acc = dev_acc
                log_msg +=  ", Dev %s: %.4f, Accuracy: %.4f" % ('Cross_Entropy', dev_loss, dev_acc)
                log.info(log_msg)
                early_stopping(dev_loss, model, self.optimizer, epoch, step, self.input_vocab, self.expt_dir)
                print('dev_ early stopping', early_stopping.counter)
                if self.dev_acc > best_accuracy:
                    best_accuracy = self.dev_acc
                    Checkpoint(model=model, optimizer= self.optimizer, 
                               epoch=epoch, step=step, input_vocab=self.input_vocab).save(self.expt_dir +'/best_accuracy')
                    print('*'*100)
                    print('SAVE MODEL (BEST DEV ACC)')

            if test_data is not None:
                model.eval()
                test_loss, accuracy = self.evaluator.evaluate(model, test_data)
                log_msg +=  ", Test %s: %.4f, Accuracy: %.4f" % ('Cross_Entropy', test_loss, accuracy)
                log.info(log_msg)

            if early_stopping.early_stop:
                print("Early Stopping")
                checkpoint = Checkpoint.get_latest_checkpoint(self.expt_dir + '/best_accuracy')
                checkpoint = Checkpoint.load(checkpoint)
                model = checkpoint.model
                # config 
                optimizer = checkpoint.optimizer
                resume_optim  = checkpoint.optimizer.optimizer
                defaults = resume_optim.param_groups[0]
                defaults.pop('params', None)
                defaults.pop('initial_lr', None)
                optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)
                self.optimizer = optimizer
                loss, accuracy = self.evaluator.evaluate(model, test_data)
                print('LOAD BEST ACCURACY MODEL ::: loss > {} accuracy{}'.format(loss, accuracy))
                break
    return model
            
            
def ssl_train(self, num_epochs=30, dev_data=None, test_data=None, methods=None, reverse_augment=False):
    log = self.logger
    self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    model = self.model
    
    for epoch in range(num_epochs):
        print('INFO. # current epoch:', epoch)

        # load updated unlabeled examples dataset
        unlabeled_dataset = torchtext.data.Dataset(self.unlabeled_examples,
                                                   fields = [('text', self.TEXT_field), ('label', self.LABEL_field)])
        unlabel_batch_iter = torchtext.data.BucketIterator(dataset=unlabeled_dataset, batch_size=128,
                                                           sort_key=lambda x:len(x.text),  sort_within_batch=True,
                                                           device=self.device, repeat=False, shuffle=False, train=False)
        # predict unlabeled dataset
        labeled_samples = []
        
        # model evaluation
        model.eval()
        reverse_batch_iter = None
        
        with torch.no_grad():
            for idx, unlabel_batch in enumerate(unlabel_batch_iter):
                str_list = itos(self.input_vocab, unlabel_batch.text[0])
                str_list = [sent.replace('<pad>','').strip() for sent in str_list]
                
                input_var = unlabel_batch.text[0]
                input_len = unlabel_batch.text[1]
                target_label = unlabel_batch.label
                logits, attn = model(input_var, input_len.tolist())
                if methods == 'self-training':
                    labeled_dataset = self.self_training_labeling(input_var, str_list, logits, attn, target_label)
                elif methods == 'pseudo-label':
                    labeled_dataset = self.pseudo_labeling(input_var, str_list, logits, attn, target_label)
                labeled_samples.extend(labeled_dataset)
        labeled_samples = list(set(labeled_samples))
        
        if reverse_augment is False:
            labeled_samples = self.balancing_labeled_result(labeled_samples)
            print('INFO. # num of new balanced dataset from pseudo label result', len(labeled_samples))
        else:
            self.save_labeled_result(self.outputs_dir+'/pseudo_label_epoch_{}.txt'.format(str(epoch)), labeled_samples)
            reversed_examples = self.gen_reverse_examples(labeled_samples)
        self.update_dataset(labeled_samples)
        
        if reverse_augment is True:
            self.labeled_examples.extend(reversed_examples)
            print('ADD REVERSED_EXAMPLES {} -> {}'.format(len(reversed_examples),len(self.labeled_examples)))

        model = self._train_epoches(self.labeled_examples, model, 30, dev_data, test_data)
        
        if self.ssl_early_stopping_patience == self.ssl_early_stopping or len(self.unlabeled_examples) == 0:
            print('SSL EARLY STOPPING EPOCH {}'.format(epoch))
            break

It seems you are returning a model from self._train_epoches. Assuming that function applies your optimizer to the gradients of the parameters you want to train, you probably don't need to return anything: model is a reference, so you can simply drop the return value of self._train_epoches(...) and let the function update the model in place. What is probably happening is that you add a new model's worth of memory each time this function is called, i.e. you make a copy of the model without freeing the previous copy's memory. Just a guess.

Also, if you are going to call del on variables, I would follow it up with gc.collect(), since I sometimes find that del alone doesn't really clear the memory but only removes the variable's name binding.
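
For example, here is a minimal sketch of that pattern (hypothetical helper names; it assumes a torchtext-style iterator whose batches expose .text and .label, like the one in your _train_epoches, and train_batch_fn stands in for your _train_batch):

import gc

import torch


def train_one_epoch(model, iterator, train_batch_fn):
    # Hypothetical helper mirroring the inner loop of _train_epoches.
    loss_total, step = 0.0, 0
    for batch in iterator:
        inputs, lengths = batch.text
        targets = batch.label
        loss = train_batch_fn(inputs, lengths.tolist(), targets, model)
        loss_total += loss.item()  # keep only the Python float, not the graph
        step += 1
        del loss, batch            # drop the name bindings explicitly
    gc.collect()                   # then actually collect the dropped objects
    torch.cuda.empty_cache()       # optionally hand cached blocks back to the driver
    return loss_total / max(step, 1)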

I want to replace the previous model with the improved model, so I assign model = improved_model. Do you think that reassigning model is causing the memory problem? If so, could you explain the reason in more detail? Thank you.

Without reading too deeply into your code, I suspect that this reassignment is the problem (just a guess).
What I am trying to say is, suppose you have code like this:

def areTorchModulesEqual(module1, module2):
    # Returns True only if every corresponding parameter tensor is identical.
    for p1, p2 in zip(module1.parameters(), module2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True

def train_model(model, optimizer, data_iterator, num_epochs, loss_function):
    for epoch in range(num_epochs):
        for batch in data_iterator:
            training_data = batch[0]
            targets = batch[1]
            optimizer.zero_grad()
            output = model(training_data)
            loss = loss_function(output, targets)
            loss.backward()
            optimizer.step()
    # No return value is needed: optimizer.step() mutates the parameter
    # tensors in place, so the caller's reference already sees the trained weights.

modelA = getInitialModel()
modelB = getInitialModel()
modelB.load_state_dict(modelA.state_dict())
train_model(modelA, optimizer, data_iterator, num_epochs, loss_function)
print(areTorchModulesEqual(modelA, modelB))  # should print False after training

To see where exactly the GPU usage is increasing, in case my theory is wrong, you can use something like this:

def printGPUInfo(verbose=False):
    print("total memory available: " + str(torch.cuda.get_device_properties(0).total_memory * 1e-9))
    print("Currently reserved: " + str(torch.cuda.memory_reserved(0) * 1e-9))
    print("Currently allocated: " + str(torch.cuda.memory_allocated(0) * 1e-9))
    print("Max reserved " + str(torch.cuda.max_memory_reserved(0) * 1e-9))
    print("Max allocated " + str(torch.cuda.max_memory_reserved(0) * 1e-9))
    print("\n")
    if verbose:
        print(torch.cuda.memory_summary(device=None, abbreviated=False))

Call the above function before and after your call to self._train_epoches, or go line by line through the trouble areas, to see where memory is not being released.
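
For example (a minimal sketch, assuming the same attribute and method names as in your snippet):

# Compare memory before and after one outer training epoch.
printGPUInfo()
model = self._train_epoches(self.labeled_examples, model, 30, dev_data, test_data)
printGPUInfo(verbose=True)  # dump the full memory summary if usage keeps growing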

Thank you for spending time on this bug. I found that memory usage increases when the embedding is updated; the problem was in the model code:

    self.embedding = nn.Embedding.from_pretrained(embedding)
    if embedding is not None:
        self.embedding.weight = nn.Parameter(embedding)
    self.embedding.weight.requires_grad = False  #True

The memory problem was triggered by setting requires_grad = True on the pretrained embedding.
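
For reference, nn.Embedding.from_pretrained already copies the weights and freezes them by default (freeze=True), so re-wrapping them in a fresh nn.Parameter turns requires_grad back on. A minimal sketch of a frozen pretrained embedding, with a random placeholder tensor standing in for the real matrix:

import torch
import torch.nn as nn

pretrained = torch.randn(30000, 300)  # placeholder for the real pretrained matrix
embedding = nn.Embedding.from_pretrained(pretrained, freeze=True)

assert embedding.weight.requires_grad is False  # frozen: no gradient buffers or optimizer state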