Weight normalization of lazily initialized modules

Hi experts,
I am trying to apply weight normalization (torch.nn.utils.parametrizations.weight_norm) to SAGEConv.
Inside the SAGEConv class, I added a weight_norm call around each of its Linear instances.
However, SAGEConv is lazily initialized with SAGEConv(-1, …), so this fails because the parameters are still uninitialized at the point where weight_norm is applied.
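For reference, here is a minimal sketch of the same issue, using a plain nn.LazyLinear instead of the Linear layers inside SAGEConv (names and sizes are just for illustration, not my actual model):

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

# Lazily initialized layer: the weight is an UninitializedParameter
# until the first forward pass, so its shape is not known yet.
lazy = nn.LazyLinear(out_features=16)
try:
    # This is analogous to what I do inside SAGEConv, and it fails
    # because the weight has not been materialized yet.
    lazy = weight_norm(lazy, name="weight")
except Exception as e:
    print(type(e).__name__, e)

# With an eagerly initialized layer the same call works fine,
# since the weight shape is already known.
dense = weight_norm(nn.Linear(8, 16), name="weight")
print([name for name, _ in dense.named_parameters()])
```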
Would you know how I can make this work?
Regards,
David