Avoid caching for parametrization

I would like to apply an automatic transformation (for instance L2 normalization) to a non-learnable parameter. PyTorch parametrization looked like a good candidate:

parametrize.register_parametrization(self, "u", L2Normalize())

This lets me update my parameter u easily in a loop within a single training step:

for _ in range(n_iter):
    self.u = myfunction(self.u)

This works well when parametrization caching is deactivated, but once caching is enabled with parametrize.cached(), any access to self.u returns the first value (the one computed before the updates in the loop).

This could be a bug: the assignment self.u = …, which calls the right_inverse method and updates self.parametrizations.u.original, should also discard the cache.
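
Here is a minimal sketch that reproduces what I see (the module and the update function are illustrative stand-ins, not my actual code):

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class Demo(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("u", torch.randn(n))
        parametrize.register_parametrization(self, "u", L2Normalize())

demo = Demo(4)

with parametrize.cached():
    first = demo.u.clone()               # first access fills the cache
    for _ in range(3):
        # stand-in for myfunction; each assignment calls right_inverse
        # and rewrites demo.parametrizations.u.original
        demo.u = demo.u + 1.0
    print(torch.equal(demo.u, first))    # True: reads stay stale inside cached()

# the underlying tensor was rewritten by the assignments
print(demo.parametrizations.u.original)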

That said, the way my code uses parametrization is not the classic one. Does anyone have an alternative solution that supports both repeated assignment and automatic normalization?

Otherwise, possible modifications to PyTorch parametrizations could be:

  • to discard the cached tensor when an assignment is made (as described above)
  • to provide a flag in register_parametrization to prevent caching for specific tensors (see the sketch after this list)
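
For the second option, the API could hypothetically look something like this (the cached flag does not exist today; it is only meant to illustrate the proposal):

# hypothetical flag, not part of the current API
parametrize.register_parametrization(self, "u", L2Normalize(), cached=False)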

Thanks for your help