Hello, I’m having trouble initializing a module that contains parameters with registered parametrization.
My understanding:
The initialization functions in torch.nn.init modify tensors in place. When a parametrization is registered, the attribute becomes a property, and the tensor returned by its getter is the output of the parametrization's forward call. Modifying that tensor in place therefore doesn't change the original (unparametrized) tensor and doesn't trigger a call to right_inverse.
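To double-check this understanding, here is a minimal sketch (using nn.ReLU directly as the parametrization, purely for illustration): the parameter moves under module.parametrizations, and every attribute access recomputes a new tensor.

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

layer = nn.Linear(3, 1, bias=False)
parametrize.register_parametrization(layer, "weight", nn.ReLU())

# The only parameter left is the unparametrized tensor under parametrizations.weight.original
print([name for name, _ in layer.named_parameters()])  # ['parametrizations.weight.original']
# "weight" is now a property; its getter re-runs forward and returns a fresh tensor each time
print(layer.weight is layer.weight)  # False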
A trivial example with a ReLU parametrization applied to the weight of a Linear layer:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class ReLUParametrization(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)

    def right_inverse(self, x: torch.Tensor) -> torch.Tensor:
        # Map an arbitrary tensor into the image of the parametrization
        return torch.clamp(x, min=0)


class MyModule(nn.Linear):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=False)
        parametrize.register_parametrization(self, "weight", ReLUParametrization())


def initialize(module: nn.Module):
    if isinstance(module, MyModule):
        # In-place init on the tensor returned by the property getter
        torch.nn.init.uniform_(module.weight, 0, 1)


if __name__ == "__main__":
    my_module = MyModule(3, 1)
    print(my_module.weight)
    my_module.apply(initialize)
    print(my_module.weight)
    # tensor([[0.2571, 0.0000, 0.3429]], grad_fn=<ReluBackward0>)
    # tensor([[0.2571, 0.0000, 0.3429]], grad_fn=<ReluBackward0>)
    # -> identical: the in-place uniform_ had no effect on the stored weight
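For completeness, adding one more print inside the __main__ block above (not part of the output shown) confirms that the unparametrized tensor stored at parametrizations.weight.original is untouched as well:

    print(my_module.parametrizations.weight.original)
    # unchanged by uniform_: the in-place op only modified the temporary tensor returned by the getter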
Could you please explain the recommended way to initialize parameters that have a registered parametrization?