How to implement Pytorch equivalent of Keras' kernel weight regulariser

Hi, I wanted to implement a PyTorch equivalent of the Keras code below.

    self.regularizer = self.L2_offdiag(l2=1)  # initialised with an arbitrary value

    Dense(classes, input_shape=[classes], activation="softmax",
          kernel_initializer=keras.initializers.Identity(gain=1),
          bias_initializer="zeros", kernel_regularizer=self.regularizer,
          bias_regularizer=keras.regularizers.l2(l=1))

    class Regularizer(object):
        """
        Regularizer base class.
        """

        def __call__(self, x):
            return 0.0

        @classmethod
        def from_config(cls, config):
            return cls(**config)


    class L2_offdiag(Regularizer):
        """
        Regularizer for L2 regularization off diagonal.
        """

        def __init__(self, l2=0.0):
            """
            Params:
                l2: (float) lambda, the L2 regularization factor.
            """
            self.l2 = K.cast_to_floatx(l2)

        def __call__(self, x):
            """
            Off-diagonal regularization (complementary regularization)
            """

            reg = 0

            for i in range(0, x.shape[0]):
                reg += K.sum(self.l2 * K.square(x[0:i, i]))
                reg += K.sum(self.l2 * K.square(x[i+1:, i]))
                
            return reg


I understand that I can use torch.nn.Linear(classes, classes) to create a layer matching the Keras implementation. But how do I integrate the regularisers and the bias into the layer?

The bias parameter is added to nn.Linear automatically, as long as you don't pass bias=False when creating the layer.
To add a regularization term for the weight parameter, you could add it to the loss manually:

    output = model(input)
    loss = criterion(output, target)
    loss = loss + torch.norm(model.layer.weight, p=2)
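
If you also want to match the off-diagonal behaviour of your Keras regularizer, a minimal sketch could look like the following. Note that `classes`, the `l2_offdiag` helper, and the commented training-loop lines are assumptions based on your snippet, not an existing PyTorch API:

    import torch
    import torch.nn as nn

    classes = 10  # hypothetical number of classes

    # Linear layer matching the Keras Dense: identity weight init, zero bias
    layer = nn.Linear(classes, classes)
    nn.init.eye_(layer.weight)
    nn.init.zeros_(layer.bias)

    def l2_offdiag(weight, l2=1.0):
        # sum of the squared off-diagonal entries, scaled by l2
        off_diag = weight - torch.diag(torch.diag(weight))
        return l2 * (off_diag ** 2).sum()

    # inside the training loop (sketch):
    # output = layer(input)            # use nn.CrossEntropyLoss, which includes the softmax
    # loss = criterion(output, target)
    # loss = loss + l2_offdiag(layer.weight, l2=1.0) + 1.0 * (layer.bias ** 2).sum()
    # loss.backward()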

Thanks, is there a default L2 regularisation which I need to disable explicitly before adding the regularisation term myself?

Optimizers accept a weight_decay argument, which adds an L2 penalty on all parameters, so you might want to make sure it is set to 0 (its default value).
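
For example (the choice of optimizer and learning rate here is arbitrary):

    # weight_decay defaults to 0 in torch.optim optimizers;
    # setting it explicitly makes sure no extra L2 penalty is applied
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0)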
