Learning a hyper-parameter for a loss function regularization term

I was wondering: in the case of a regularized loss function, is it possible to learn a hyper-parameter of the loss function while training the network in PyTorch? I have looked at many research papers, and it seems everyone is “tuning” the hyper-parameter rather than learning it.

What kind of hyper-parameter would you like to tune?
Usually the hyper-parameters of the training routine are e.g. the learning rate, weight decay, etc.
You could tune these using your training and validation loss, but I’m not sure I completely understand the connection to the loss function.

Thanks for the response. There is a method for imposing physical constraints on the neural network, in which a physics-based loss is added to the loss function. This term is usually a function of the output of the network.

As a simplistic example, assume the network outputs a number that should not fall below “3”. We add a (penalty) regularization term of the form “max(3 - output, 0)”, which penalizes the network when the output is below 3 and vanishes when the output is above 3.
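In PyTorch, such a penalty can be written directly with torch.clamp. A minimal sketch of what I mean (all names here are toy placeholders, not my actual model):

```python
import torch
import torch.nn as nn

# Toy stand-ins; the real model, data, and base loss would go here.
model = nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
criterion = nn.MSELoss()

output = model(inputs)
task_loss = criterion(output, targets)

# Physics penalty: max(3 - output, 0) is zero whenever output >= 3.
penalty = torch.clamp(3.0 - output, min=0).mean()

lam = 0.1  # the multiplier, tuned by hand for now
loss = task_loss + lam * penalty
loss.backward()
```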

As this is essentially a constrained optimization problem (minimizing the NN loss function subject to the physical constraint above), which is turned into a regularized unconstrained optimization problem (in the form of a Lagrangian function), we need to find a Lagrange multiplier for the regularization term. This Lagrange multiplier is what I’m wondering whether it is possible to learn, rather than tune (as suggested in many papers).
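Naively, I would just register the multiplier as a learnable parameter and let the optimizer handle it along with the weights, something like this rough sketch (toy names again, and I’m not sure this is sound):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # toy placeholder network
criterion = nn.MSELoss()
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

# Naive idea: make the multiplier a learnable parameter
# (softplus keeps it positive) and minimize jointly with the network.
raw_lam = nn.Parameter(torch.zeros(()))
optimizer = torch.optim.Adam(list(model.parameters()) + [raw_lam], lr=1e-3)

output = model(inputs)
lam = nn.functional.softplus(raw_lam)
loss = criterion(output, targets) + lam * torch.clamp(3.0 - output, min=0).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```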

You can see one such work here:
https://papers.nips.cc/paper/7942-constrained-generation-of-semantically-valid-graphs-via-regularizing-variational-autoencoders.pdf (Section 3 and Eq. 8 in particular)

Thank you!

Thanks for the information! I’m unfortunately not experienced enough in this area, but maybe @tom might know, if that’s possible. :slight_smile:


So I’m not sure I understand correctly what you have in mind, but here are a few quick comments that shape my intuition around this:

  • As you write, the Lagrange multiplier helps turn a constrained problem into an unconstrained one. In doing so, however, it turns the (say) minimization problem into a saddle-point problem: you want min_{variables} max_{Lagrange multipliers} L, and whenever your constraint is not satisfied, the max part goes to infinity, so that point is irrelevant to the min (if an admissible point exists). (See the first sketch after this list.)
  • If you look at regularization, you’ll typically have a weight for the regularization term as a hyperparameter. Now, if you naïvely included that weight in your (minimizing) optimization while enforcing positivity, you would just drive it to 0, because, obviously, no penalty is cheaper than a penalty.
  • You can think of tuning as “optimizing hyperparameters” (and the Bayesian crowd will have thoughts on integrating over hyperparameters instead). This can be done if you set up an objective separate from the “inner” optimization objective, e.g. (as for manual tuning) the loss on a validation set. This often goes under the label “meta-learning”. (See the second sketch after this list.)
  • The perhaps best-studied regularization method is Tikhonov regularization (leading to “ridge regression”). As the Wikipedia page notes, the optimal regularization parameter is typically unknown, but the page does offer some heuristics and a discussion of helpful interpretations.
  • Yarin Gal, in his thesis, has a section on optimizing the dropout probability in the context of variational Bayesian inference (Section 6.4, “ELBO correlation with test log likelihood”), where he reports mixed results and offers some hypotheses for possible causes. I must admit I did not follow whether further research exists on any of these hypotheses.
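To make the first two bullets concrete: the usual way out of the saddle-point structure is gradient descent on the network combined with gradient ascent on the multiplier. A rough sketch under toy assumptions (invented names, a single fixed batch, and no claims about stability or step sizes):

```python
import torch
import torch.nn as nn

# Toy stand-ins; the real model, data, and base loss go here.
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

# The Lagrange multiplier is a tensor of its own, not a model parameter.
lam = torch.zeros((), requires_grad=True)
lr_lam = 1e-2

opt_model = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(100):
    output = model(inputs)
    violation = torch.clamp(3.0 - output, min=0).mean()  # > 0 iff constraint broken
    lagrangian = criterion(output, targets) + lam * violation

    opt_model.zero_grad()
    lam.grad = None
    lagrangian.backward()

    opt_model.step()  # descent on the network parameters

    # Ascent on the multiplier: this sign flip is what keeps it from
    # collapsing to zero the way a jointly *minimized* weight would.
    with torch.no_grad():
        lam += lr_lam * lam.grad
        lam.clamp_(min=0)  # multipliers for inequality constraints stay nonnegative
```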
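And for the third bullet, the simplest validation-based tuning is just an outer search over candidate weights. Purely illustrative, with made-up toy data and an arbitrary grid:

```python
import torch
import torch.nn as nn

def train_and_validate(lam, train_data, val_data):
    """Train a fresh toy model with penalty weight lam, return validation loss."""
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    for inputs, targets in train_data:
        output = model(inputs)
        loss = criterion(output, targets) + lam * torch.clamp(3.0 - output, min=0).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return sum(criterion(model(x), y) for x, y in val_data) / len(val_data)

# Outer "meta" loop: pick the weight that does best on held-out data.
train_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(50)]
val_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]
best_lam = min([0.01, 0.1, 1.0, 10.0],
               key=lambda lam: train_and_validate(lam, train_data, val_data))
```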

Best regards

Thomas


Thank you for the insight. I recently encountered this paper and its implementation in TF. Although they address a more general case, where the constraints are not differentiable, using a so-called proxy Lagrangian, I was wondering if you have encountered similar PyTorch implementations (not necessarily for non-differentiable constraints, just general enough to work for non-convex DNNs).