Unable to call load_state_dict when nn.ModuleList is used

Hello,

I am using `nn.ModuleList` to build a model made up of `CustomSubModule` layers:


class CustomSubModule(T.nn.Module):
    """Bias-free linear layer with a separately stored bias and two buffers.

    Args:
        fan_in: number of input features of the inner ``nn.Linear``.
        fan_out: number of output features of the inner ``nn.Linear``.
        weight: tensor wrapped as the inner layer's ``weight`` parameter;
            presumably shaped ``(fan_out, fan_in)`` like any ``nn.Linear``
            weight -- not checked here.
        bias: tensor stored as this module's own trainable ``bias`` parameter.
        alpha, beta: tensors registered as buffers, so they appear in
            ``state_dict`` but are not returned by ``parameters()``.
    """

    def __init__(self, fan_in, fan_out, weight, bias, alpha, beta):
        super().__init__()

        # Build the linear map without its own bias, then replace its
        # freshly initialised weight with the caller-supplied tensor.
        linear = T.nn.Linear(fan_in, fan_out, bias=False)
        linear.weight = T.nn.Parameter(weight)
        self.no_bias_layer = linear

        self.bias = T.nn.Parameter(bias)

        # Registration order fixes the order of the buffers in state_dict.
        for buffer_name, value in (('alpha', alpha), ('beta', beta)):
            self.register_buffer(buffer_name, value)


class CustomModule(T.nn.Module):
    """Container that stacks layers in an ``nn.ModuleList`` named ``layers``.

    Each appended layer is registered under ``layers.<index>``, so its
    parameters and buffers show up in ``state_dict`` as ``layers.0.bias``,
    ``layers.0.alpha`` and so on. With the default strict loading,
    ``load_state_dict`` reports those keys as unexpected if the target
    instance does not already contain the same layers. To restore a
    checkpoint, rebuild the layers first, either with ``layers=`` or with
    ``add_layer``, and then load.
    """

    def __init__(self, layers=None):
        """Create the container.

        Args:
            layers: optional iterable of modules to register immediately, in
                order. ``None`` (the default) starts with an empty list,
                which matches the previous no-argument behaviour.
        """
        super().__init__()

        # ModuleList(None) is an empty list, so the old call still works.
        self.layers = T.nn.ModuleList(layers)

    def add_layer(self, layer: 'CustomSubModule'):
        """Append ``layer``; it is registered as ``layers.<len(layers)>``."""
        self.layers.append(layer)

However, I can't load the saved state_dict with `load_state_dict`:

# Build a one-layer model and write its state_dict to disk.
custom_sub_module = CustomSubModule(
    10, 10,
    T.rand(10, 10), T.rand(10, ),
    T.tensor([5]), T.tensor([4]),
)

costum_module = CustomModule()
costum_module.add_layer(custom_sub_module)
saved_state = costum_module.state_dict()
T.save(saved_state, '/tmp/test_model.pt')

# A freshly constructed CustomModule has no layers yet, so every
# "layers.0.*" key in the checkpoint is unexpected and this call raises
# the RuntimeError shown below.
loaded_module = CustomModule()
loaded_module.load_state_dict(T.load('/tmp/test_model.pt'))
Traceback (most recent call last):
  File "tests/save_dict_example.py", line 36, in <module>
    loaded_module.load_state_dict(T.load('/tmp/test_model.pt'))
  File ".venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CustomModule:
	Unexpected key(s) in state_dict: "layers.0.bias", "layers.0.alpha", "layers.0.beta", "layers.0.no_bias_layer.weight". 

Is there any way to overcome this?

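Shouldn't you call `add_layer` on the loaded model as well? I think `layers` is empty: a freshly constructed `CustomModule` starts with an empty `ModuleList`, so `load_state_dict` has no modules to put the `layers.0.*` entries into and reports them as unexpected. Rebuild the same architecture first, then load. Can you try something like this: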

loaded_module = CustomModule()

# Recreate the saved layer structure first. The tensor values passed here are
# only placeholders; load_state_dict replaces them with the checkpoint's
# values. Their shapes presumably have to match the saved ones -- confirm.
placeholder_layer = CustomSubModule(10, 10, T.rand(10, 10), T.rand(10, ),
                                    T.tensor([5]), T.tensor([4]))
loaded_module.add_layer(placeholder_layer)

loaded_module.load_state_dict(T.load('/tmp/test_model.pt'))