Initializing a parameter with a registered parametrization

Hello, I’m having trouble initializing a module that contains parameters with registered parametrization.

My understanding:
The initialization functions in torch.nn.init modify the tensors in place. When a parametrization is registered, the attribute becomes a property, and the tensor returned by its getter is the output of the parametrization's forward call. Modifying that tensor in place therefore doesn't change the original tensor's values and doesn't trigger a call to right_inverse.

A trivial example with a ReLU parametrization applied to the weight of a Linear layer:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class ReLUParametrization(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.nn.functional.relu(x)
    
    def right_inverse(self, x: torch.Tensor):
        return torch.clamp(x, 0)

class MyModule(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=False, device=None, dtype=None)
        parametrize.register_parametrization(self, "weight", ReLUParametrization())

def initialize(module: torch.nn.Module):
    if isinstance(module, MyModule):
        torch.nn.init.uniform_(module.weight, 0, 1)

if __name__ == "__main__":
    my_module = MyModule(3, 1)
    print(my_module.weight)
    my_module.apply(initialize)
    print(my_module.weight)

# tensor([[0.2571, 0.0000, 0.3429]], grad_fn=<ReluBackward0>)
# tensor([[0.2571, 0.0000, 0.3429]], grad_fn=<ReluBackward0>)

Could you please explain the recommended way to initialize parameters with registered parametrization?

I’m unsure if I understand your use case correctly, but you could try to initialize the original parameter via module.parametrizations.weight.original.
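For completeness, a minimal sketch of that idea, continuing the MyModule example from your post (assuming the class defined above; both variants are only suggestions):

def initialize(module: torch.nn.Module):
    if isinstance(module, MyModule):
        # Option 1: initialize the unconstrained tensor directly. It is a plain
        # nn.Parameter, so the in-place torch.nn.init functions work on it.
        torch.nn.init.uniform_(module.parametrizations.weight.original, 0, 1)

        # Option 2 (alternative): assign a new tensor to the parametrized
        # attribute; the setter calls right_inverse on the value before
        # storing it in .original.
        # module.weight = torch.rand(module.out_features, module.in_features)

my_module = MyModule(3, 1)
my_module.apply(initialize)
print(my_module.weight)  # the values now change after apply(initialize)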

Thank you for the help, I’ll look into that!