Attempted to use an uninitialized parameter in <built-in method empty_like of type object at 0x00007FFF5B826810>. This error happens when you are using a `LazyModule` or explicitly manipulating `torch.nn.parameter.UninitializedParameter` objects

I encountered an issue when using GradSampleModule to wrap my own module. The original layer is torch_geometric.nn.Linear. Using that Linear with lazy initialization works fine, but after I subclassed Linear to implement my own MaskedLinear, I started getting the following error, and I’m not sure how to resolve it.

If the problem is related to lazy initialization, why didn’t I have this issue when using the original Linear?
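
For context, this is the kind of minimal check I used to convince myself that the lazy path of the original Linear is fine on its own (not my actual training code; the sizes are arbitrary):

import torch
from torch_geometric.nn import Linear

lin = Linear(-1, 16)          # in_channels=-1: resolved lazily on the first forward
lin(torch.randn(8, 32))       # a dummy batch materializes the weight
print(type(lin.weight))       # now a regular torch.nn.Parameter, no longer uninitialized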

Below is the traceback:

Traceback (most recent call last):
  File "D:\pycharm\PyCharm 2024.1.3\plugins\python\helpers\pydev\pydevd.py", line 1546, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "D:\pycharm\PyCharm 2024.1.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:\deep_learning\rPDP-GAP - test4\experiments\fedavg_rpdp.py", line 45, in <module>
    metrics = method.fit()
  File "D:\deep_learning\rPDP-GAP - test4\core\methods\node\gap\gap_ndp.py", line 111, in fit
    self.calibrate()
  File "D:\deep_learning\rPDP-GAP - test4\core\methods\node\gap\gap_ndp.py", line 104, in calibrate
    self._classifier[i] = self.classifier_noisy_sgd.prepare_module(self._classifier[i])
  File "D:\deep_learning\rPDP-GAP - test4\core\privacy\algorithms\noisy_sgd.py", line 52, in prepare_module
    GradSampleModule(module).register_backward_hook(forbid_accumulation_hook)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\opacus\grad_sample\grad_sample_module.py", line 140, in __init__
    self.add_hooks(
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\opacus\grad_sample\grad_sample_module.py", line 186, in add_hooks
    prepare_layer(module, batch_first=batch_first)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\opacus\grad_sample\functorch.py", line 20, in prepare_layer
    flayer, _ = make_functional(layer)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\functorch\_src\make_functional.py", line 380, in make_functional
    return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\functorch\_src\make_functional.py", line 303, in _create_from
    params, param_names, names_map = extract_weights(model_copy)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\functorch\_src\make_functional.py", line 114, in extract_weights
    return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\functorch\_src\make_functional.py", line 94, in _extract_members
    memo[p] = subclass(torch.empty_like(p, device='meta'))
  File "C:\Users\MSI\.conda\envs\rpdp39-gap\lib\site-packages\torch\nn\parameter.py", line 144, in __torch_function__
    raise ValueError(
ValueError: Attempted to use an uninitialized parameter in <built-in method empty_like of type object at 0x00007FFF5B826810>. This error happens when you are using a `LazyModule` or explicitly manipulating `torch.nn.parameter.UninitializedParameter` objects. When using LazyModules Call `forward` with a dummy batch to initialize the parameters before calling torch functions

This is the line that went wrong:

GradSampleModule(module).register_backward_hook(forbid_accumulation_hook)
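
Right before that call I can dump which parameters are still uninitialized; report_uninitialized below is just a throwaway helper I use for debugging, not part of any library:

import torch.nn as nn

def report_uninitialized(module: nn.Module) -> None:
    # Print every parameter that is still an UninitializedParameter at wrap time.
    for name, p in module.named_parameters():
        if isinstance(p, nn.parameter.UninitializedParameter):
            print(f"still uninitialized: {name}")

# e.g. report_uninitialized(module) right before GradSampleModule(module)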

The module passed in is a multi-layer MLP that was originally built from torch_geometric Linear layers; the only change I made was swapping those for my own MaskedLinear layers.
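
Here is a simplified sketch of how the classifier is put together (the real one is the MultiMLP from the GAP code; hidden_dim and num_classes are placeholders, and MaskedLinear itself is defined further down):

import torch.nn as nn

def make_mlp(hidden_dim: int = 16, num_classes: int = 7) -> nn.Sequential:
    # Same structure as the working version, with Linear swapped for MaskedLinear.
    return nn.Sequential(
        MaskedLinear(-1, hidden_dim),   # -1: in_channels resolved lazily
        nn.ReLU(),
        MaskedLinear(-1, num_classes),
    )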

My MaskedLinear implementation is below:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Linear
# helper used below; if it is not available in your PyG version, an isinstance
# check against nn.parameter.UninitializedParameter does the same thing
from torch_geometric.nn.dense.linear import is_uninitialized_parameter


class MaskedLinear(Linear):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__(in_channels, out_channels, bias=bias)
        # The mask shares the weight's shape, so it is also created lazily.
        self.mask = nn.parameter.UninitializedParameter()

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        # Forward pre-hook (registered by the parent Linear in the lazy case):
        # materialize weight and mask once the input feature size is known.
        if is_uninitialized_parameter(self.weight):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.mask.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        self._hook.remove()
        delattr(self, '_hook')

    def forward(self, x):
        # Element-wise mask applied to the weight before the affine transform.
        return F.linear(x, self.weight * self.mask, self.bias)
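
For completeness, this is how I exercise the layer on its own with a dummy batch (a sketch; using in_channels=-1 to trigger the lazy path and the sizes below are my own arbitrary choices):

layer = MaskedLinear(-1, 16)
x = torch.randn(8, 32)            # dummy batch; its last dim becomes in_channels
out = layer(x)                    # the pre-hook materializes both weight and mask
print(layer.weight.shape, layer.mask.shape)   # both expected to be torch.Size([16, 32])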