Failed to script a module containing torch.nn.utils.parametrize.register_parametrization

Hi, I am trying to `torch.jit.script` a module that contains a parametrization, as follows:

class SWSConv2d(nn.Conv2d):
    r"""
    2D convolution whose ``weight`` is wrapped in Scaled Weight Standardization.

    The standardization is attached as a parametrization on ``weight`` when the
    layer is constructed, and can be detached again with
    :meth:`remove_parametrization`.

    Reference:
        Characterizing signal propagation to close the performance gap in
        unnormalized ResNets
        https://arxiv.org/abs/2101.08692
    """

    def __init__(self,
                 in_channels: int, out_channels: int, kernel_size: int,
                 stride: int=1, padding: int=0, padding_mode: str='zeros', dilation=1, groups: int=1,
                 bias: bool=True):
        """Build the underlying ``nn.Conv2d`` and attach the weight parametrization.

        All arguments are forwarded unchanged to ``nn.Conv2d.__init__``.
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # The conv weight exists only after the parent constructor has run,
        # so the parametrization is attached afterwards.
        self.register_parametrization()

    def register_parametrization(self):
        """Attach the standardization to ``weight`` unless it is already parametrized."""
        if TP.is_parametrized(self, 'weight'):
            return

        # The first dimension of the weight tensor is passed as ``out_channels``.
        standardization = ScaledWeight2DStandardization(
            out_channels=self.weight.shape[0],
            use_gain=True,
            eps=1e-4,
        )
        TP.register_parametrization(self, 'weight', standardization)

    def remove_parametrization(self):
        """Detach the parametrization from ``weight`` if one is registered.

        ``leave_parametrized=True`` is passed to ``remove_parametrizations``.
        """
        if not TP.is_parametrized(self, 'weight'):
            return

        TP.remove_parametrizations(self, 'weight', leave_parametrized=True)

Calling `model_jit = torch.jit.script(model)` then fails with the following error:

  File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/jit/frontend.py", line 137, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/jit/frontend.py", line 330, in __call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: global variables aren't supported:
  File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/nn/utils/parametrize.py", line 166
    def get_parametrized(self) -> Tensor:
        global _cache
        ~~~~~~ <--- HERE
    
        parametrization = self.parametrizations[tensor_name]