Delete a tensor with PyTorch + NN

While training a model, I have noticed that the “forward pass” does not release its memory. So I have been trying to manually delete those tensors to control memory usage:

import gc
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda


d = 2
l = 170
M = 22
N = 10000
data_name = 'Duffing_oscillator'  # prefix of the CSV files loaded below

I = torch.eye(M + 3, M + 3)

net = nn.Sequential(
    nn.Linear(d, l),
    nn.Tanh(),
    nn.Linear(l, l),
    nn.Tanh(),
    nn.Linear(l, l),
    nn.Tanh(),
    nn.Linear(l, M),
)

loss_fn = torch.nn.MSELoss(reduction='sum')

optimizer = optim.Adam(net.parameters(), lr=0.001) 

def data_Preprocessing(tr_val_te, cut):
    # load the first `cut` rows of ./<data_name>_<tr_val_te>.csv as a float32 tensor
    data = np.loadtxt('./%s_%s.csv' % (data_name, tr_val_te), delimiter=',', dtype=np.float64)[:cut]
    data = torch.tensor(data, dtype=torch.float32)
    return data

x_data = data_Preprocessing("train_x", N)
y_data = data_Preprocessing("train_y", N)

# forward pass:

pred_sai = net(x_data)
y_pred_sai = net(y_data)

# loss calculation omitted here to keep the example short; it produces the scalar tensor loss


optimizer.zero_grad()

loss.backward()


# DELETE THE TENSORS
del pred_sai, y_pred_sai
gc.collect()

optimizer.step()
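
To double-check what del actually does, I count the tensors that are still reachable. count_live_tensors is a little debugging helper of my own, not a library function:

def count_live_tensors():
    # count tensors still tracked by the garbage collector
    n = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
        except Exception:
            pass  # some tracked objects raise on inspection; skip them
    return n

print(count_live_tensors())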

I have analysed the memory usage line by line with memory-profiler · PyPI.
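
The whole training code above is wrapped in a function and profiled roughly like this (a sketch; my_func is just my name for that wrapper):

from memory_profiler import profile

@profile(precision=8)
def my_func():
    # ... all of the training code shown above ...
    ...

my_func()  # executed with: python -m memory_profiler script.py

This is the output: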

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    28 156.79296875 MiB 156.79296875 MiB           1   @profile(precision=8)
    29                                         
    30                                         def my_func():
    31                                             
    32 156.79296875 MiB   0.00000000 MiB           1       data_name = 'Duffing_oscillator' # spectrum    Duffing_oscillator
    33                                             
    34 156.79296875 MiB   0.00000000 MiB           1       lambda_ = 1e-2  
    35                                             
    36 156.79296875 MiB   0.00000000 MiB           1       epsilon = 20
    37 156.79296875 MiB   0.00000000 MiB           1       d = 2
    38 156.79296875 MiB   0.00000000 MiB           1       l = 170  
    39 156.79296875 MiB   0.00000000 MiB           1       M = 22
    40                                             
    41                                             
    42 157.21484375 MiB   0.42187500 MiB           1       I = torch.eye(M + 3, M + 3)
    43                                             
    44                                             
    45 157.21484375 MiB   0.00000000 MiB           1       N = 10000
    46 157.21484375 MiB   0.00000000 MiB           1       inv_N = 1/N  #0.1
    47                                             
    48 157.21484375 MiB   0.00000000 MiB           1       net = nn.Sequential(
    49 157.57812500 MiB   0.36328125 MiB           1           nn.Linear(d, l),
    50 157.57812500 MiB   0.00000000 MiB           1           nn.Tanh(),
    51 157.57812500 MiB   0.00000000 MiB           1           nn.Linear(l, l),
    52 157.57812500 MiB   0.00000000 MiB           1           nn.Tanh(),
    53 157.82031250 MiB   0.24218750 MiB           1           nn.Linear(l, l),
    54 157.82031250 MiB   0.00000000 MiB           1           nn.Tanh(),
    55 157.82031250 MiB   0.00000000 MiB           1           nn.Linear(l, M),
    56                                             )
    57                                             
    58                                             
    59 157.82031250 MiB   0.00000000 MiB           1       learning_rate = 0.001
    60                                             
    61 157.82031250 MiB   0.00000000 MiB           1       optimizer = optim.Adam(net.parameters(), lr=learning_rate) # amsgrad=True)
    62                                             
    63 157.82031250 MiB   0.00000000 MiB           1       loss_fn = torch.nn.MSELoss(reduction='sum')
    64                                             
    65                                             
    66 161.80078125 MiB   0.00000000 MiB           3       def data_Preprocessing(tr_val_te, cut):
    67 162.98828125 MiB   5.16796875 MiB           2           data = np.loadtxt(('./%s_%s.csv' % (data_name, tr_val_te)), delimiter=',', dtype=np.float64)[:cut]
    68 162.98828125 MiB   0.00000000 MiB           2           data = torch.tensor(data, dtype=torch.float32)
    69 162.98828125 MiB   0.00000000 MiB           2           return data
    70                                             
    71                                             
    72 272.84765625 MiB   0.00000000 MiB           2       def Frobenius_norm(X):
    73 272.84765625 MiB   0.00000000 MiB           1           M = torch.mm(X, torch.transpose(X, 0, 1))
    74 272.84765625 MiB   0.00000000 MiB           1           return torch.sum(torch.diag(M, 0))
    75                                             
    76                                             
    77                                                 
    78                                                 
    79 157.82031250 MiB   0.00000000 MiB           1       x = []
    80 157.82031250 MiB   0.00000000 MiB           1       y = []
    81 157.82031250 MiB   0.00000000 MiB           1       K_tilde = []
    82                                             
    83                                             
    84                                             #net input
    85 161.80078125 MiB   0.00000000 MiB           1       x_data = data_Preprocessing("train_x", N)
    86 162.98828125 MiB   0.00000000 MiB           1       y_data = data_Preprocessing("train_y", N)
    87                                             
    88                                             
    89 162.98828125 MiB   0.00000000 MiB           1       loss = float("INF")
    90                                             
    91 163.39062500 MiB   0.26562500 MiB       10003       fixed_sai = torch.tensor([i + [0.1] for i in x_data.detach().tolist()], dtype=torch.float32)
    92 165.37500000 MiB   1.98437500 MiB       10003       y_fixed_sai = torch.tensor([i + [0.1] for i in y_data.detach().tolist()], dtype=torch.float32)
    93                                             
    94                                             
    95 235.71484375 MiB  70.33984375 MiB           1       pred_sai = net(x_data)  
    96 271.46875000 MiB  35.75390625 MiB           1       y_pred_sai = net(y_data)
    97                                         
    98 271.69531250 MiB   0.22656250 MiB           1       pred_sai = torch.cat([pred_sai, fixed_sai], dim=1)
    99 271.69531250 MiB   0.00000000 MiB           1       y_pred_sai = torch.cat([y_pred_sai, y_fixed_sai], dim=1)
   100                                         
   101                                         
   102 271.69531250 MiB   0.00000000 MiB           1       pred_sai_T = torch.transpose(pred_sai, 0, 1)
   103 271.69531250 MiB   0.00000000 MiB           1       y_pred_sai_T = torch.transpose(y_pred_sai, 0, 1)
   104                                         
   105                                              
   106 272.84765625 MiB   1.15234375 MiB           1       K_tilde = torch.mm(torch.pinverse(inv_N * torch.mm(pred_sai_T, pred_sai)  + lambda_ * I), inv_N * torch.mm(pred_sai_T, y_pred_sai))
   107 272.84765625 MiB   0.00000000 MiB           1       K_tilde = K_tilde.detach()
   108                                         
   109                                                
   110 273.03906250 MiB   0.19140625 MiB           1       res = lambda_ * Frobenius_norm(K_tilde) ** 2  
   111 273.53906250 MiB   0.50000000 MiB           1       MSE = (y_pred_sai_T - torch.mm(K_tilde, pred_sai_T))** 2  
   112 273.53906250 MiB   0.00000000 MiB           1       loss = torch.sum(MSE) +res
   113                                         
   114                                         
   115                                         
   116 273.53906250 MiB   0.00000000 MiB           1       optimizer.zero_grad()   
   117 313.44921875 MiB  39.91015625 MiB           1       loss.backward()    
   118                                             

As you can see, del followed by gc.collect() releases less than 1 MiB, while the two forward passes alone allocated about 106 MiB:

   120                                             ###########################
   121 313.44921875 MiB   0.00000000 MiB           1       del pred_sai,y_pred_sai
   122 312.71484375 MiB  -0.73437500 MiB           1       gc.collect()
   123                                             ###########################
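
To make my expectation explicit: once loss.backward() has run (with the default retain_graph=False), the autograd graph releases its saved activations, so deleting the outputs and collecting should leave nothing holding that memory. A minimal standalone version of the lifecycle I have in mind (sizes arbitrary, not my real data):

net2 = nn.Linear(1000, 1000)
x2 = torch.randn(4096, 1000)

out = net2(x2)            # forward: autograd saves activations for the backward pass
loss2 = (out ** 2).sum()
loss2.backward()          # backward frees the graph's saved tensors

del out, loss2            # drop the last Python references to the outputs
gc.collect()              # I expected the memory to be returned at this point

In my real script, however, the resident memory barely moves.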

I have also tried to implement what is suggested in this post: avoiding full gpu memory occupation during training in pytorch - Chadrick's Blog.
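
Roughly, what I took from that post is to drop every reference after the backward pass and empty the CUDA cache (I am paraphrasing from memory, so this may not match the post exactly):

# after loss.backward() and optimizer.step():
del pred_sai, y_pred_sai, loss
gc.collect()
if cuda.is_available():
    cuda.empty_cache()  # returns cached blocks to the CUDA driver (GPU only)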

But it does not solve the problem.

I was wondering if you have any suggestions.