Enforcing parameter constraints in a Module

I have a custom Module that is a reusable layer, and it needs to constrain its weight parameters to the range [0, 1]. The simplest and most obvious way to do this is to clip them after every optimiser step. I'm not interested in techniques that merely encourage the constraint through a modified loss, because it is mathematically invalid for the weights to ever leave that range.
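Concretely, this is the kind of constraint I mean (assuming PyTorch, which is where Module, forward and optimiser.step() come from; BoundedLayer and enforce_constraints are just hypothetical names for illustration):

```python
import torch
import torch.nn as nn

class BoundedLayer(nn.Module):
    """Hypothetical reusable layer whose weights must stay in [0, 1]."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Initialise inside the valid range.
        self.weight = nn.Parameter(torch.rand(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t()

    def enforce_constraints(self):
        # Clip in place, outside the autograd graph.
        with torch.no_grad():
            self.weight.clamp_(0.0, 1.0)
```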

I can make this work by calling a constraint function after optimiser.step() in my training loop. However, the constraint is an implementation detail of the module, and I don't want every training loop of every model that uses this module internally to have to remember to call that function in order to work correctly. The responsibility for enforcing the constraint should live within the module itself.
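For reference, the working but leaky version looks roughly like this (continuing the hypothetical BoundedLayer above; `loader` is assumed to yield (input, target) batches):

```python
model = nn.Sequential(BoundedLayer(10, 5), nn.ReLU(), nn.Linear(5, 1))
optimiser = torch.optim.SGD(model.parameters(), lr=1e-2)

for x, y in loader:
    optimiser.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimiser.step()
    # The step I want to stop having to remember: re-apply the
    # constraint after every parameter update.
    for m in model.modules():
        if isinstance(m, BoundedLayer):
            m.enforce_constraints()
```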

I’ve not been able to find any hooks that fire after an update occurs, so I’m not sure how to go about enforcing this. Is there a nice way to do this?

The best I’ve come up with is enforcing the constraint in forward, but that’s wasteful when the weights haven’t changed since the last call.
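That is, something like this in the hypothetical layer above:

```python
def forward(self, x):
    # Re-clamp on every call, even though the weights only change
    # when the optimiser actually steps.
    with torch.no_grad():
        self.weight.clamp_(0.0, 1.0)
    return x @ self.weight.t()
```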

Another approach would be to apply some kind of normalisation in the forward pass, but that feels hacky when all I need to do is constrain the values in a tensor, and I don’t want to add extra complexity to the graph.
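By normalisation I mean a reparameterisation along these lines, where `raw_weight` is a hypothetical unconstrained parameter that gets squashed into range on every forward pass:

```python
class ReparamLayer(nn.Module):
    """Hypothetical variant: store an unconstrained parameter and map it
    into (0, 1) in the forward pass, at the cost of extra graph ops."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.raw_weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # sigmoid is just one possible squashing function; it adds a node
        # to the autograd graph on every call.
        constrained_weight = torch.sigmoid(self.raw_weight)
        return x @ constrained_weight.t()
```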