L1 Regularization Kaggle Example - But it works!

Hi guys, I have been working with a regularized network for some months now. Last winter I improved its performance by applying L1 regularization. The network behaved exactly as it should and improved. I did not look at the code in detail for a while. It is a simple MLP with 3 layers and some dropout … A few days ago I improved the performance again by using a different activation function, and that is when I saw it:

The L1 term is not added to the loss, it is subtracted from it. When I debugged it, I also verified that L1_term * l1_lambda is always positive. I corrected the “error”, but the performance (MAE; it is a regression task) got worse.
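To make the difference explicit (a minimal snippet using the same variable names as in my training loop below):

# what the Kaggle example effectively does: the penalty is subtracted,
# so gradient descent is rewarded for making the weights larger
loss = loss - L1_term * l1_lambda

# textbook L1 regularization: the penalty is added, shrinking the weights
loss = loss + L1_term * l1_lambda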

It turns out I took the original code from a Kaggle example: Minimizing Loss using L1 Regularization in Pytorch | Kaggle

Do you have any idea? Thanks a lot.

Here I paste my training loop (the line where I had the ‘error’ is marked with !!!; in the broken version the + was a -):

# training loop
for epoch in range(n_epochs):
    model.train()
    model = model.to(device)
    losses_sum = 0
    n_batches = 0

    print(f'starting epoch {epoch+1}')
    logging.info(f'starting epoch {epoch+1}')

    reg = 0

    # step by batch_size; stepping by 1 makes consecutive batches overlap
    for start in range(0, X_train.shape[0], batch_size):
        # take a batch
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]

        # forward pass
        y_pred = model.predict(X_batch)
        y_batch = torch.Tensor(y_batch)
        y_pred = y_pred.cpu()
        y_batch = y_batch.cpu()

        loss = loss_fn(y_pred, y_batch)

        # calculate the L1 term over all weight (non-bias) parameters
        L1_term = torch.tensor(0.)
        for name, weights in model.named_parameters():
            if 'bias' not in name:
                L1_term = L1_term + torch.sum(torch.abs(weights))

        # nweights is the total number of weight entries in the model
        L1_term = L1_term / nweights

        # regularize the loss using L1 regularization
        loss = loss + L1_term * l1_lambda  # !!! this + was a - before

        reg = reg + float(L1_term * l1_lambda)

        writer.add_scalar("Loss/train", loss, epoch)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # update weights
        optimizer.step()

        # accumulate as a float so we do not keep the autograd graph alive
        losses_sum += loss.item()
        n_batches += 1

    print('Reg_sum:')
    print(reg)

I notice that you’re logging your training loss but not your validation loss. The purpose of regularization is to prevent overfitting the training data. It might be possible that by subtracting your regularization term you are encouraging overfitting, which is why you’re getting better training loss? I would be curious to see if validation loss goes down. It might just be learning to increase the weights too.
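One quick way to test the "increasing the weights" hypothesis is to log the mean absolute weight once per epoch. A sketch, reusing your model and writer:

# sketch: track the average |weight| per epoch; if the subtracted penalty
# is simply inflating the weights, this curve will climb steadily
with torch.no_grad():
    weight_params = [p for n, p in model.named_parameters() if 'bias' not in n]
    mean_abs = (sum(p.abs().sum().item() for p in weight_params)
                / sum(p.numel() for p in weight_params))
writer.add_scalar("Weights/mean_abs", mean_abs, epoch)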

Also welcome to the forum!

Thanks for the reply and for welcoming me to the forum :slight_smile:
This is just the training part; the validation loss is also calculated. So when I say improve, I mean on the validation set. I will insert the code below.

    writer.flush()

    # validation loss
    model.eval()  # disable dropout for the validation pass
    losses_sum_val = 0
    n_val_batches = 0

    with torch.no_grad():
        for start in range(0, X_test.shape[0], batch_size):
            X_batch = X_test[start:start+batch_size]
            y_batch = y_test[start:start+batch_size]

            # forward pass
            y_predicted_val = model.predict(X_batch)

            y_predicted_val = y_predicted_val.cpu()
            y_batch = torch.Tensor(y_batch).to('cpu')

            loss_val = loss_fn(y_predicted_val, y_batch)

            # add the L1 term (left over from the last training batch) so the
            # value is comparable to the regularized training loss; drop this
            # line if the validation loss should reflect pure data fit
            loss_val = loss_val + L1_term * l1_lambda

            losses_sum_val += loss_val.item()
            n_val_batches += 1

    mean_loss = losses_sum / n_batches
    print(mean_loss)
    logging.debug(f'mean Loss: {mean_loss}')

    mean_loss_val = losses_sum_val / n_val_batches
    print(mean_loss_val)
    logging.debug(f'mean Loss validation: {mean_loss_val}')

    # evaluate MAE at the end of each epoch (it is a regression task)
    y_pred = model.predict(X_test)
    y_predicted_train = model.predict(X_train)

    y_pred = y_pred.cpu().detach().numpy()
    y_predicted_train = y_predicted_train.cpu().detach().numpy()

    from sklearn.metrics import mean_absolute_error as mae
    r2_model_KORA, pearson_model_KORA, yhat_linear_KORA = evaluation(model, X_train.to('cpu'), X_test.to('cpu'), y_train.to('cpu'), y_test, True)
    y_test = torch.Tensor(y_test).cpu()
    y_train = torch.Tensor(y_train).cpu()

    MAE = mae(y_test, y_pred)
    MAE_train = mae(y_train, y_predicted_train)

    print(MAE)
    print(MAE_train)


There may be something funky in your code. I just adapted an MNIST example to add L1 regularization with a negative lambda just to double-check, and I saw the training loss shoot down and the validation loss shoot up.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

batch_size = 64
learning_rate = 0.01
epochs = 10
lambda_l1 = -0.01  # Regularization strength, can be negative to see the effect

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        
        # Adding L1 regularization
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        loss += lambda_l1 * l1_norm
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    train_loss = running_loss / len(train_loader)
    
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

output where lambda = -0.01

Epoch 1/10, Training Loss: -72.5726, Validation Loss: 0.4573
Epoch 2/10, Training Loss: -175.2643, Validation Loss: 0.6409
Epoch 3/10, Training Loss: -276.9435, Validation Loss: 0.8348
Epoch 4/10, Training Loss: -378.4700, Validation Loss: 1.1148
Epoch 5/10, Training Loss: -480.1062, Validation Loss: 0.9174
Epoch 6/10, Training Loss: -581.9961, Validation Loss: 1.1283
Epoch 7/10, Training Loss: -684.0639, Validation Loss: 1.2179
Epoch 8/10, Training Loss: -786.1513, Validation Loss: 1.1799
Epoch 9/10, Training Loss: -888.2631, Validation Loss: 1.2234
Epoch 10/10, Training Loss: -990.3663, Validation Loss: 1.4981

I also ran the same setup with a positive lambda of the same magnitude; the training loss went down for one epoch, but then training and validation stayed almost exactly the same.

output where lambda = 0.01

Epoch 1/10, Training Loss: 6.0010, Validation Loss: 2.3024
Epoch 2/10, Training Loss: 2.3657, Validation Loss: 2.3023
Epoch 3/10, Training Loss: 2.3573, Validation Loss: 2.3023
Epoch 4/10, Training Loss: 2.3573, Validation Loss: 2.3023
Epoch 5/10, Training Loss: 2.3573, Validation Loss: 2.3023
Epoch 6/10, Training Loss: 2.3573, Validation Loss: 2.3023
Epoch 7/10, Training Loss: 2.3573, Validation Loss: 2.3022
Epoch 8/10, Training Loss: 2.3573, Validation Loss: 2.3022
Epoch 9/10, Training Loss: 2.3573, Validation Loss: 2.3022
Epoch 10/10, Training Loss: 2.3573, Validation Loss: 2.3022
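That plateau around 2.302 is ln(10), the cross-entropy of a 10-class model that predicts uniform probabilities, so the lambda = 0.01 penalty most likely drove the weights to zero. A quick sparsity check (a sketch, run after training) would confirm it:

# sketch: fraction of near-zero weights after training; a too-strong
# L1 penalty should push most entries to (almost) exactly zero
with torch.no_grad():
    flat = torch.cat([p.abs().flatten() for p in model.parameters()])
    print(f"near-zero weights: {(flat < 1e-6).float().mean().item():.1%}")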

So I tried a few different values of lambda; decreasing it by an order of magnitude yielded better results (a minimal sweep sketch follows the output below).

output where lambda = 0.001

Epoch 1/10, Training Loss: 2.8704, Validation Loss: 0.4596
Epoch 2/10, Training Loss: 1.6505, Validation Loss: 0.3784
Epoch 3/10, Training Loss: 1.1952, Validation Loss: 0.3734
Epoch 4/10, Training Loss: 0.9795, Validation Loss: 0.3638
Epoch 5/10, Training Loss: 0.8832, Validation Loss: 0.3490
Epoch 6/10, Training Loss: 0.8180, Validation Loss: 0.3296
Epoch 7/10, Training Loss: 0.7662, Validation Loss: 0.3146
Epoch 8/10, Training Loss: 0.7252, Validation Loss: 0.3153
Epoch 9/10, Training Loss: 0.6919, Validation Loss: 0.3005
Epoch 10/10, Training Loss: 0.6642, Validation Loss: 0.3016
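For completeness, the sweep itself can be scripted; run_training here is a hypothetical wrapper around the training loop above that returns the final validation loss for a given lambda:

# sketch: sweep a few regularization strengths
# (run_training is a hypothetical wrapper around the loop above)
for lambda_l1 in (0.01, 0.001, 0.0001, 0.0):
    final_val_loss = run_training(lambda_l1)
    print(f"lambda={lambda_l1}: val_loss={final_val_loss:.4f}")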

Oddly enough, I tried -0.001 and it did better than the positive version.

output where lambda = -0.001

Epoch 1/10, Training Loss: -1.7663, Validation Loss: 0.4144
Epoch 2/10, Training Loss: -3.4858, Validation Loss: 0.3253
Epoch 3/10, Training Loss: -4.5694, Validation Loss: 0.2775
Epoch 4/10, Training Loss: -5.6204, Validation Loss: 0.2814
Epoch 5/10, Training Loss: -6.6634, Validation Loss: 0.2320
Epoch 6/10, Training Loss: -7.7021, Validation Loss: 0.2158
Epoch 7/10, Training Loss: -8.7343, Validation Loss: 0.2076
Epoch 8/10, Training Loss: -9.7614, Validation Loss: 0.2049
Epoch 9/10, Training Loss: -10.7848, Validation Loss: 0.1928
Epoch 10/10, Training Loss: -11.8046, Validation Loss: 0.1898

…but ultimately, what performed best was no L1 regularization at all.

output where lambda = 0.0

Epoch 1/10, Training Loss: 1.0738, Validation Loss: 0.4299
Epoch 2/10, Training Loss: 0.3833, Validation Loss: 0.3256
Epoch 3/10, Training Loss: 0.3215, Validation Loss: 0.2919
Epoch 4/10, Training Loss: 0.2900, Validation Loss: 0.2741
Epoch 5/10, Training Loss: 0.2659, Validation Loss: 0.2505
Epoch 6/10, Training Loss: 0.2451, Validation Loss: 0.2295
Epoch 7/10, Training Loss: 0.2251, Validation Loss: 0.2161
Epoch 8/10, Training Loss: 0.2067, Validation Loss: 0.1984
Epoch 9/10, Training Loss: 0.1904, Validation Loss: 0.1891
Epoch 10/10, Training Loss: 0.1761, Validation Loss: 0.1710

Note: the training loss is higher than the validation loss in many of these runs because the regularization term is included in the reported training loss but not in the validation loss.
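If you want the two curves directly comparable, log the raw data loss separately from the regularized objective. A minimal change to the training loop above:

# sketch: log the unregularized data loss, backprop the regularized one
data_loss = criterion(output, target)
loss = data_loss + lambda_l1 * l1_norm  # the objective actually optimized
loss.backward()
running_loss += data_loss.item()        # now comparable to val_loss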

Haha, okay, thanks for that nice experiment … I also tried small positive values on my dataset now and got better results (but not a lot better) … Maybe it is not a sign error after all, but just some random effect …
