Question about CPU memory usage during training

Hi guys,
I'm training a model in PyTorch: a simple two-layer neural network on the MNIST dataset, with a custom method named LS applied to every neuron between the two hidden layers.
However, during training the memory usage grows steadily until the process runs out of memory and is killed in the first epoch.
Below is my implementation of the network and the training process:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28*1, 20)
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = torch.sigmoid(x1)
        x3 = self.fc2(x2)
        com_x3 = torch.sigmoid(x3)
        x2 = x2.clone().requires_grad_()
        for i in range(x2.shape[0]):
            for j in range(x2.shape[1]):
                for k in range(com_x3.shape[1]):
                    # Apply LS element-wise
                    x2[i, j] = LS(x2[i, j], com_x3[i, k])
        return x3

def train(epoch):
    train_loss = 0  
    train_accuracy = 0  
    model.train()
    for data, label in tqdm(loader_train, desc="Training"):
        data, label = data.view(-1, 28*28).to(device), label.to(device)
        optimizer.zero_grad()  
        y_pred_prob = model(data)  
        torch.autograd.set_detect_anomaly(True)  # note: anomaly detection is enabled here on every iteration
        loss = loss_fn(y_pred_prob, label)  
        loss.backward()  
        optimizer.step()  
        train_loss += loss.item()
        y_pred_label = torch.max(y_pred_prob, 1)[1]
        train_accuracy += torch.sum(y_pred_label == label).item() / len(label)

I'm using PyTorch 2.6, running the code on WSL2 (24 GB of RAM allocated to WSL). Batch size is 128.
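For reference, the growth is easy to see by logging resident memory per step, e.g. with something like the snippet below (a minimal sketch: psutil is a third-party package, and log_rss is just an illustrative helper, not part of my actual code):

import os
import psutil  # third-party: pip install psutil

_proc = psutil.Process(os.getpid())

def log_rss(step, every=50):
    # Print the process's resident set size (RSS) every `every` steps
    if step % every == 0:
        print(f"step {step}: RSS = {_proc.memory_info().rss / 2**20:.1f} MiB")

Calling log_rss(step) at the end of each iteration in train shows the RSS climbing batch after batch.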
Please help me with this problem.
Thank you and best regards.

Unrelated to your memory issues, but do you mean to return x2 instead of x3 at the end of your forward pass?

Also, the triple-nested for loop seems like a red flag here: with batch size 128 that's 128 * 20 * 10 = 25,600 scalar LS calls per batch, each adding nodes to the autograd graph. I'd recommend trying to vectorize it, but without knowing what LS is I won't jump to that conclusion.

Here’s my LS method:

def LS(x, y):
    ls_a = x
    ls_b = 1.0 - x
    ls_c = 1.0 - y
    ls_d = y

    # Guard against exact zeros before dividing
    if ls_a == 0.0: ls_a = 1e-03
    if ls_b == 0.0: ls_b = 1e-03
    if ls_c == 0.0: ls_c = 1e-03
    if ls_d == 0.0: ls_d = 1e-03

    term_a = (ls_b * ls_d) / (ls_b + ls_d)
    term_b = (ls_a * ls_c) / (ls_a + ls_c)
    term_c = (ls_b * ls_d) / (ls_b + ls_d)  # same value as term_a

    # Guard the derived terms as well
    if term_a == 0.0: term_a = 1e-03
    if term_b == 0.0: term_b = 1e-03
    if term_c == 0.0: term_c = 1e-03

    ls_res = (ls_a + term_a) / (ls_a + ls_b + term_b + term_c)
    return ls_res

I want to update x2 with LS and have the updated values take part in the backward pass.
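
Since LS is just element-wise arithmetic, it should vectorize cleanly. Here's an untested sketch of what that could look like: replace_zeros and ls_vec are names I made up, torch.where stands in for the scalar zero checks, and the loop over k stays as a sequential fold since your code overwrites x2[i, j] once per k.

import torch

def replace_zeros(t, eps=1e-3):
    # Tensor analogue of the scalar `if t == 0.0: t = 1e-03` guards
    return torch.where(t == 0.0, torch.full_like(t, eps), t)

def ls_vec(x, y):
    # Element-wise LS on whole tensors; a (N, 1) column y broadcasts against a (N, 20) x
    a = replace_zeros(x)
    b = replace_zeros(1.0 - x)
    c = replace_zeros(1.0 - y)
    d = replace_zeros(y)
    term_a = replace_zeros((b * d) / (b + d))
    term_b = replace_zeros((a * c) / (a + c))
    # term_c in your scalar version recomputes the same value as term_a, so it is reused here
    return (a + term_a) / (a + b + term_b + term_a)

With that, the triple loop in forward could collapse to a single loop over k, with no clone().requires_grad_() and no in-place writes, so gradients flow through the result naturally:

out = x2
for k in range(com_x3.shape[1]):
    out = ls_vec(out, com_x3[:, k:k+1])  # keep the column as (N, 1) so it broadcasts
return out  # or x3, depending on which output you actually intend

That turns 128 * 20 * 10 scalar LS calls per batch into 10 vectorized ones, which should keep the autograd graph much smaller.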