RAM usage keeps increasing while training deep learning model in pytorch

I am training a deep learning model for unsupervised domain adaptation and I have this issue that while training the RAM usage keeps going up while I actually expect that the iteration i should take the same RAM of iteration i-1. I read that in some cases the dataloaders had some problems of RAM usage, so I tried to load the data “manually” to see if that was the problem, but apparently it’s not.
Here is my code:

class GCANDataset(Dataset):

def __init__(self, source_path, target_path, class_list, train_portion=0.8, transform=None):

    self.transform = transform
    self.source_path = source_path
    self.target_path = target_path
    self.file_counter = 0
    dict = {}
    
    self.source_images = list()
    self.source_labels = list()
    self.target_images = list()
    self.target_labels = list()
    source_length = 0
    
    for folder in os.listdir(source_path):
        full_folder = os.path.join(source_path, folder)
        for img in os.listdir(full_folder):
            full_path = os.path.join(full_folder, img)
            self.file_counter += 1
            source_length += 1
            self.source_images.append(Image.open(full_path))
            self.source_labels.append(folder)
            

    for folder in os.listdir(target_path):
        full_folder = os.path.join(target_path, folder)
        for img in os.listdir(full_folder):
            full_path = os.path.join(full_folder, img)
            self.file_counter += 1
            self.target_images.append(Image.open(full_path))
            self.target_labels.append(folder)

    for i in range(len(class_list)):
        dict[class_list[i]] = i
    
    
    self.number_of_train = int(train_portion * self.file_counter)
    self.number_of_test = self.file_counter - self.number_of_train

    indexes = list(range(len(self.target_images)))

    target_test_split = random.sample(indexes, self.number_of_test)
    target_train_split = [item for item in indexes if item not in target_test_split]
    self.train_images = self.source_images + [self.target_images[i] for i in target_train_split]
    self.train_labels = self.source_labels + [self.target_labels[i] for i in target_train_split]
    self.train_domain = [0.] * source_length + [1.] * len(target_train_split)
    self.test_images = [self.target_images[i] for i in target_test_split]
    self.test_labels = [self.target_labels[i] for i in target_test_split]

    self.train_labels = [torch.Tensor([float(dict[label])]).to(torch.long) for label in self.train_labels]
    self.test_labels = [torch.Tensor([float(dict[label])]).to(torch.long) for label in self.test_labels]

    temp = list(zip(self.test_images, self.test_labels))
    random.shuffle(temp)
    self.test_images, self.test_labels = zip(*temp)
    self.test_images, self.test_labels = list(self.test_images), list(self.test_labels)

    temp = list(zip(self.train_images, self.train_labels, self.train_domain))
    random.shuffle(temp)
    self.train_images, self.train_labels, self.train_domain = zip(*temp)
    self.train_images, self.train_labels, self.train_domain = list(self.train_images), list(self.train_labels), list(self.train_domain)


def __getitem__(self, index):

    if index < self.number_of_train:
        image = self.train_images[index]
        label = self.train_labels[index]
        domain = self.train_domain[index]
    else:
        index -= self.number_of_train
        image = self.test_images[index]
        label = self.test_labels[index]
        domain = torch.Tensor([1.])

    if self.transform is not None:
        image = self.transform(image)
    
    return image, label[0].to(torch.float), torch.Tensor([domain]).to(torch.long)


def __len__(self):
    return self.file_counter




def divide_source_and_target(x, pred, target, domain):

    x_source = np.array(list())
    x_target = np.array(list())
    pred_source = np.array(list())
    pred_target = np.array(list())
    target_source = np.array(list())
    target_target = np.array(list())

    for i in range(len(domain)):
        if domain[i] == 0:
            np.append(x_source, x[i].detach().cpu().numpy())
            np.append(pred_source, pred[i].detach().cpu().numpy())
            np.append(target_source, target[i].detach().cpu().numpy())
        else:
            np.append(x_target, x[i].detach().cpu().numpy())
            np.append(pred_target, pred[i].detach().cpu().numpy())
            np.append(target_target, target[i].detach().cpu().numpy())

    return x_source, x_target, pred_source, pred_target, target_source, target_target



def uda_classification_loss(predicted, target):

   loss_function = nn.CrossEntropyLoss()
   return loss_function(torch.from_numpy(predicted), torch.from_numpy(target))


def uda_domain_alignment_loss(domain_pred, domain_target):

    loss_function = nn.BCELoss()
    return loss_function(domain_pred.to(torch.float), domain_target.to(torch.float))

def uda_structure_aware_alignment_loss(x_source, x_target, threshold=1):

    if len(x_source) >= 2:
        first_img_source, second_img_source = random.sample(x_source, 2)
    else:
        return 0

    if len(x_target) >= 1:
        target_img = random.sample(x_target, 1)
    else:
        return 0

    first_img_source = first_img_source.cpu().numpy()
    second_img_source = second_img_source.cpu().numpy()
    target_img = target_img[0].cpu().numpy()

    score_source, _ = ssim(first_img_source, second_img_source, win_size=3, full=True, multichannel=True)
    score_target, _ = ssim(first_img_source, target_img, win_size=3, full=True, multichannel=True)

    score_source = score_source ** 2
    score_target = score_target ** 2 

    final_value = score_source + score_target + threshold

    return max(final_value, 0)


def uda_class_alignment_loss(x, domains, pseudo_classes, classes):

    source_classes = {}
    target_classes = {}

    temp_x = x.detach().cpu()
    temp_domains = domains.detach().cpu()
    temp_pseudo_classes = pseudo_classes.detach().cpu()
    temp_classes = classes.detach().cpu()

    for i in range(len(domains)):

        if domains[i] == 0:
             key = temp_classes[i]
        
             if temp_classes[i] in source_classes:
                 source_classes[key].append(temp_x[i])
             else:
                 source_classes[key] = list(temp_x[i])
        else:
             key = temp_pseudo_classes[i]
         
             if temp_pseudo_classes[i] in target_classes:
                 target_classes[key].append(temp_x[i])
             else:
                 target_classes[key] = list(temp_x[i])


    for key in source_classes:

         values = source_classes[key]
         matrix = list()
         for arr in values:
              matrix.append(arr)
    
         sum_arr = np.sum(np.array(matrix), axis=0)
         sum_arr /= len(values)

    for key in target_classes:

        values = target_classes[key]
        matrix = list()
        for arr in values:
            matrix.append(arr)

        sum_arr = np.sum(np.array(matrix), axis=0)
        sum_arr /= len(values)


  sum_distance = 0

  for key in source_classes:
    
      if key in target_classes:
        
          source_prototype = source_classes[key]
          target_prototype = target_classes[key]
          dist = euclidean(source_prototype, target_prototype) ** 2
          sum_distance += dist

   return sum_distance



def uda_loss(x, class_prediction, domain_prediction, target_class, target_domain, mid_results):

    class_prediction_weight = 1
    domain_prediction_weight = 0.001
    structure_aware_alignment_weight = 0.001
    class_alignment_loss = 0.001

    x_source, x_target, class_prediction_source, class_prediction_target, target_class_source, target_class_target = \
            divide_source_and_target(x, class_prediction, target_class, target_domain)

    classification_loss = uda_classification_loss(class_prediction_source, target_class_source)
    domain_loss = uda_domain_alignment_loss(domain_prediction, target_domain)
    triplet_loss = uda_structure_aware_alignment_loss(x_source, x_target)
    class_alignment_loss = uda_class_alignment_loss(mid_results, target_domain, class_prediction, target_class)
    

    return class_prediction_weight * classification_loss + domain_prediction_weight * domain_loss + \
            structure_aware_alignment_weight * triplet_loss + class_alignment_loss * class_alignment_loss


def get_batch_indexes(data, batch_size):
    
    # Dropping last incomplete batch
    train_batches_indexes = list(range(0, data.number_of_train))
    last_batch_length = len(train_batches_indexes) % batch_size
    last_batch_sample_indexes = random.sample(train_batches_indexes, last_batch_length)
    last_batch_indexes = [train_batches_indexes.remove(index) for index in last_batch_sample_indexes]
    random.shuffle(train_batches_indexes)

    train_batches_indexes = np.array(train_batches_indexes).reshape(-1, batch_size)


    test_batches_indexes = list(range(data.number_of_train, len(data)))
    last_batch_length = len(test_batches_indexes) % batch_size
    last_batch_sample_indexes = random.sample(test_batches_indexes, last_batch_length)
    last_batch_indexes = [test_batches_indexes.remove(index) for index in last_batch_sample_indexes]
    random.shuffle(test_batches_indexes)

    test_batches_indexes = np.array(test_batches_indexes).reshape(-1, batch_size)

    return train_batches_indexes, test_batches_indexes, train_batches_indexes.shape[0], test_batches_indexes.shape[0]


def take_batch(data, idx, batch_indexes):

    input = np.array(list())
    target = np.array(list())
    domain = np.array(list())

    for i in batch_indexes[idx]:

        input = np.append(input, data[i][0])
        target = np.append(target, data[i][1])
        domain = np.append(domain, data[i][2])
    
    input = input.reshape(-1, 3, 224, 224)
    domain = domain.reshape(-1, 1)

    return input, target, domain



def forward_and_backward(input, real_class, domain, model, optimizer):

    input = torch.from_numpy(input).to(torch.float).to(device)
    real_class = torch.from_numpy(real_class).to(torch.long).to(device)
    domain = torch.from_numpy(domain).to(torch.long).to(device)

    output = model.forward(input)
    domain_classification, pseudo_label, mid_out = output[0], output[1], output[2]

    cuda.empty_cache()

    loss = uda_loss(input, pseudo_label, domain_classification, 
                        real_class, domain, mid_out)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    _, predicted = pseudo_label.max(1)
    acc_to_add = predicted.eq(real_class).sum().item()

    return loss.item(), acc_to_add


def train_uda_epoch(model, data, train_batches, optimizer, train_num, batch_size, device='cuda:0'):
    
    samples = 0
    cumulative_acc = 0
    cumulative_loss = 0

    model.train()

    for i in tqdm(range(train_num)):

        gc.collect()
        print('Begin batch RAM Used (GB):', psutil.virtual_memory()[3]/1000000000)
        cuda.empty_cache()

        input, real_class, domain = take_batch(data, i, train_batches)  # Drops last batch

        loss_value, acc_to_add = forward_and_backward(input, real_class, 
                                                         domain, model, 
                                                         optimizer)

        samples += batch_size
        cumulative_loss += loss_value
        cumulative_acc += acc_to_add        

        gc.collect()
        cuda.empty_cache()

    return (cumulative_acc / samples) * 100, cumulative_loss / samples 



def train_uda_model(train_epochs, model, dataset, optimizer, batch_size, counter, device='cuda:0'):
    
    dir = 'exp_uda' + str(counter)
    path = os.path.join('runs', dir)
    writer = SummaryWriter(log_dir=path)

    start = time.time()

    for epoch in range(train_epochs):

        gc.collect()
        print('Begin epoch RAM Used (GB):', psutil.virtual_memory()[3]/1000000000)
        cuda.empty_cache()

        print(f'Creating batches for epoch {epoch + 1}... ', end='')
        train_batches, test_batches, n_train_batches, n_test_batches = get_batch_indexes(dataset, batch_size)
        print('Done')
        
        print(f'Epoch {epoch + 1} / {train_epochs}')
        acc, loss = train_uda_epoch(model, dataset, train_batches , optimizer, n_train_batches, batch_size, device)
        print(f'Test accuracy of the (UDA) model after {epoch + 1} epochs: {acc}%')
        print(f'Loss of the (UDA) model after {epoch + 1} epochs: {loss}')

        gc.collect()
        cuda.empty_cache()

        writer.add_scalar('Loss/train_loss', loss, epoch + 1)
        writer.add_scalar('Loss/train_accuracy', acc, epoch + 1)
    
    writer.flush()
    writer.close()
    
    end = time.time()
    print(f'The training finished in {end - start} seconds')
    print('-' * 70)

The ouput I get during the first epoch of training is

Epoch 1 / 10

0%| | 0/50 [00:00<?, ?it/s]

Begin batch RAM Used (GB): 5.10459904

2%|▏ | 1/50 [00:06<05:34, 6.83s/it]

Begin batch RAM Used (GB): 5.63181568

4%|▍ | 2/50 [00:13<05:21, 6.69s/it]

Begin batch RAM Used (GB): 6.180130816

6%|▌ | 3/50 [00:19<05:05, 6.50s/it]

Begin batch RAM Used (GB): 6.423826432

8%|▊ | 4/50 [00:26<04:58, 6.49s/it]

Begin batch RAM Used (GB): 6.811430912

10%|█ | 5/50 [00:32<04:53, 6.52s/it]

Begin batch RAM Used (GB): 7.094808576

12%|█▏ | 6/50 [00:39<04:48, 6.57s/it]

Begin batch RAM Used (GB): 7.477551104

14%|█▍ | 7/50 [00:46<04:49, 6.73s/it]

Begin batch RAM Used (GB): 8.150032384

16%|█▌ | 8/50 [00:53<04:41, 6.71s/it]

Begin batch RAM Used (GB): 8.372457472

18%|█▊ | 9/50 [00:59<04:32, 6.65s/it]

Begin batch RAM Used (GB): 8.723787776

20%|██ | 10/50 [01:06<04:22, 6.56s/it]

Begin batch RAM Used (GB): 9.079734272

22%|██▏ | 11/50 [01:12<04:15, 6.56s/it]

Begin batch RAM Used (GB): 9.56788736

24%|██▍ | 12/50 [01:19<04:11, 6.61s/it]

Begin batch RAM Used (GB): 9.789124608

26%|██▌ | 13/50 [01:26<04:07, 6.69s/it]

Begin batch RAM Used (GB): 10.292166656

28%|██▊ | 14/50 [01:32<04:00, 6.67s/it]

Begin batch RAM Used (GB): 10.670362624

30%|███ | 15/50 [01:41<04:11, 7.20s/it]

Begin batch RAM Used (GB): 10.951622656

32%|███▏ | 16/50 [01:49<04:19, 7.62s/it]

Begin batch RAM Used (GB): 11.328094208

34%|███▍ | 17/50 [01:56<04:04, 7.41s/it]

Begin batch RAM Used (GB): 11.705749504

36%|███▌ | 18/50 [02:03<03:53, 7.29s/it]

Begin batch RAM Used (GB): 12.08178688

38%|███▊ | 19/50 [02:10<03:44, 7.23s/it]

Begin batch RAM Used (GB): 12.376731648

40%|████ | 20/50 [02:20<03:58, 7.94s/it]

Begin batch RAM Used (GB): 12.552052736

42%|████▏ | 21/50 [02:27<03:46, 7.82s/it]

Begin batch RAM Used (GB): 12.56220672

It is hard to completely replicate your case but here’s what I guess.

Rewriting the above code to minimize the usage of .to() and .cpu() and see how it changed.

First of all, thanks for your answer, but unfortunately this is not the case. I tried to do as you said, and removed all the unnecessary .to() and .cpu() but unfortunately this didn’t solve the problem. Moreover, I investigated a bit more into the precise point in which I have this problem and it’s in particular the forward pass, which is defined as follows:

def forward(self, x):
        
        features = self.cnn.forward(x)
        scores = self.dsa.forward(x)

        transposed_scores = torch.transpose(scores, 0, 1)
        adjacency_matrix = torch.matmul(scores, transposed_scores)
        sparse_adj_matrix = dense_to_sparse(adjacency_matrix)       
        
        edge_index, edge_attr = sparse_adj_matrix[0], sparse_adj_matrix[1]
        graph = geometric_data(scores, edge_index=edge_index)
 
        gcn_features = self.gcn(graph.x, graph.edge_index)
        gcn_features = gcn_features.view(-1, 150, 1, 1)
        
        concat_features = torch.cat((features, gcn_features), dim=1)
        concat_features = concat_features.view(-1, self.combined_features)
        
        domain_classification = self.domain_alignment(concat_features)

        pseudo_label = relu(self.fc1(concat_features))
        mid_out = self.fc2(pseudo_label)
        pseudo_label = relu(mid_out)
        pseudo_label = softmax(self.fc3(pseudo_label), dim=1)
        
        
        return domain_classification, pseudo_label, mid_out

In this piece of code, cnn is a resnet50 (pretrained) without the last layer, dsa is a complete resnet50 (pretrained), gcn is the GCN network from python geometric, domain alignment is a squential module with 3 fully connected layer and fcs are fully connected layers.