Hi, I am trying to `torch.jit.script` a module that contains a parametrization, as follows:
class SWSConv2d(nn.Conv2d):
    r"""
    2D Conv layer with Scaled Weight Standardization.

    Characterizing signal propagation to close the performance gap in unnormalized ResNets
    https://arxiv.org/abs/2101.08692

    The standardization is attached to ``weight`` as a parametrization via ``TP``
    (assumed to be ``torch.nn.utils.parametrize`` -- its ``register_parametrization``,
    ``is_parametrized`` and ``remove_parametrizations`` functions are used below).

    TorchScript note: in PyTorch 1.9 the getter that ``torch.nn.utils.parametrize``
    installs for a parametrized tensor declares ``global _cache``, which
    ``torch.jit.script`` rejects with "global variables aren't supported".
    Before scripting a model that contains this layer, call
    :meth:`remove_parametrization` on every instance. This replaces the
    parametrized ``weight`` with a plain parameter that holds the current
    standardized value (``leave_parametrized=True``). The raw, unstandardized
    tensor is discarded, so do this only once training is finished.
    """

    def __init__(self,
                 in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, padding_mode: str = 'zeros',
                 dilation: int = 1, groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, padding_mode=padding_mode,
                         dilation=dilation, groups=groups, bias=bias)
        # Attach the standardization at construction so the layer uses it by default.
        self.register_parametrization()

    def register_parametrization(self):
        """Attach the scaled weight standardization to ``weight``.

        Does nothing if ``weight`` is already parametrized, so repeated calls
        do not stack several standardizations on top of each other.
        """
        if not TP.is_parametrized(self, 'weight'):
            TP.register_parametrization(
                self, 'weight',
                # out_channels is taken from dim 0 of the conv weight.
                ScaledWeight2DStandardization(out_channels=self.weight.shape[0],
                                              use_gain=True, eps=1e-4))

    def remove_parametrization(self):
        """Detach the standardization from ``weight``, keeping its current value.

        Does nothing if ``weight`` is not parametrized. With
        ``leave_parametrized=True``, ``weight`` becomes a plain parameter that
        holds the standardized tensor. This is the step required before
        ``torch.jit.script`` (see the class docstring).
        """
        if TP.is_parametrized(self, 'weight'):
            TP.remove_parametrizations(self, 'weight', leave_parametrized=True)
Calling `model_jit = torch.jit.script(model)` fails with the following error:
File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/jit/frontend.py", line 137, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/jit/frontend.py", line 330, in __call__
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: global variables aren't supported:
File "/Users/ganneheim/anaconda3/envs/pytorch1.9/lib/python3.7/site-packages/torch/nn/utils/parametrize.py", line 166
def get_parametrized(self) -> Tensor:
global _cache
~~~~~~ <--- HERE
parametrization = self.parametrizations[tensor_name]