Can I implement a custom loss function like this?

Dear PyTorch forum, I was wondering if you could help me with the implementation of a custom loss function. In particular, I want to perform some calculations on the model weights inside my loss function, for example an L1 penalty. My constraint is that the loss function has to be written as an nn.Module. Why? Because I wish to use skorch for hyperparameter tuning, and skorch only accepts nn.Module losses as the criterion (at least to the best of my understanding). I am aware that if I wrote my NN in a typical PyTorch training loop, this problem would be much easier, as I wouldn’t need to write a class for my loss; however, hyperparameter tuning is more cumbersome that way.

Assume I have a NN defined as follows:


import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    """
    Trivial network, please do not read too much into it - it's a matter of principle.
    Architecture: some trivial constant number of nodes for each HL, given number of nodes and HL.
    Parameters: 
    ------------------
    input_dim: dimension of input layer, hl: number of hidden layers, nodes: number of nodes per hl
    """
    def __init__(self, input_dim, hl, nodes):
        super().__init__()  # required before registering submodules
        self.input_dim = input_dim
        self.hl = hl
        self.nodes = nodes
        self.layers = nn.ModuleList()
        for h in range(hl):
            if h == 0:  # input layer
                self.layers.append(nn.Linear(input_dim, nodes))
                self.layers.append(nn.ReLU())
            else:
                self.layers.append(nn.Linear(nodes, nodes))
                self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(nodes, 1))
        self.model = nn.Sequential(*self.layers)

    def _get_weights(self):
        """
        Gets model weights only and ignores biases
        note: w = [*layers.parameters()] would include biases I think
        """
        w = []
        for name, param in self.layers.named_parameters():
            if 'weight' in name:
                w.append(param)
        return w

    def forward(self, x):
        """
        Unusual, because it returns 2 elements. why? because I want the weights to be passed to my loss function.
        """
        w = self._get_weights()
        x = self.model(x)
        return x, w


class RegMSE(nn.Module):
    """
    Custom loss function.
    Example: apply an L1 penalty to the weights. Please ignore whether this makes sense, it's a matter of principle.
    """
    def __init__(self, l1):
        super(RegMSE, self).__init__()
        self.l1 = l1

    def forward(self, y_pred, target):
        """
        y_pred now includes the weights also from the forward pass
        """
        # unpack y_pred
        y_hat, weights = y_pred
        # calculate typical MSE loss
        loss = 1/(y_hat.shape[0])*torch.sum((target-y_hat) ** 2)
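        # accumulator for the penalty, created on the same device and dtype as the predictions
        l = torch.zeros(1, dtype=y_hat.dtype, device=y_hat.device)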
        # L1 penalty
        if self.l1 is not None:
            for p in weights:
                l += torch.norm(p, 1)
            loss += self.l1 * l[0]
        return loss

Now in a training loop, we could have a situation where:

model = NeuralNet(input_dim=10, hl=2, nodes=10)
criterion = RegMSE(l1=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

out = model(X)
loss = criterion(out, Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Would this work to implement an L1 penalty? Meaning, would my weights be trained accordingly?
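
One way to check this empirically is to compare the weight gradients with and without the penalty. The following is only a minimal sketch: it uses a hypothetical stand-in `nn.Sequential` model and random data rather than the classes above, and the helper name `weight_grads` is made up for illustration. If the penalty is part of the autograd graph, each weight gradient should shift by `l1 * sign(w)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model and random data (not the NeuralNet class above)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
X = torch.randn(20, 10)
Y = torch.randn(20, 1)

# Weight matrices only, biases excluded
weights = [p for name, p in model.named_parameters() if "weight" in name]


def weight_grads(l1):
    """Return gradients of the weight matrices for MSE + l1 * sum(|w|)."""
    model.zero_grad()
    loss = torch.mean((Y - model(X)) ** 2)
    loss = loss + l1 * sum(w.abs().sum() for w in weights)
    loss.backward()
    return [w.grad.clone() for w in weights]


l1 = 1e-2
plain = weight_grads(l1=0.0)
penalised = weight_grads(l1=l1)

# The difference between the two gradients should be l1 * sign(w)
for g0, g1, w in zip(plain, penalised, weights):
    expected_shift = l1 * torch.sign(w.detach())
    print(torch.allclose(g1 - g0, expected_shift, atol=1e-5))
```

If every comparison prints `True`, the penalty term contributes to the gradients, so an optimizer step would take it into account.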

If so, I could then use it in skorch:

from skorch import NeuralNetRegressor

net = NeuralNetRegressor(
    module=NeuralNet,
    module__input_dim=X.shape[1],  # Input shape
    module__nodes=10,  # Number of nodes per hidden layer
    module__hl=5,  # Number of hidden layers
    criterion=RegMSE,  # Custom Loss Function
    criterion__l1=1e-4,  # L1 weight penalty
    batch_size=20, # Batch size
    max_epochs=100, # Epochs
    optimizer=torch.optim.Adam,  # Optimizer
    optimizer__lr=1e-6,  # Learning Rate
    optimizer__weight_decay=0,  # No weight penalisation in optimizer
    iterator_train__drop_last=True,  # Avoid empty batch
    verbose=True  # Print learning
)

I want to use skorch to find an “optimal” L1 penalty, for example with RandomizedSearchCV.

I really hope this makes sense, and I am grateful for any feedback! Perhaps this is not a performant solution, so if you can think of better ways of doing this, please let me know.

Felix