Hello,
I am using ModuleList to construct a model consisting of CustomSubModule.
class CustomSubModule(T.nn.Module):
def __init__(self, fan_in, fan_out, weight, bias, alpha, beta):
super().__init__()
self.no_bias_layer = T.nn.Linear(fan_in, fan_out, bias=False)
self.no_bias_layer.weight = T.nn.Parameter(weight)
self.bias = T.nn.Parameter(bias)
self.register_buffer('alpha', alpha)
self.register_buffer('beta', beta)
class CustomModule(T.nn.Module):
def __init__(self):
super().__init__()
self.layers = T.nn.ModuleList()
def add_layer(self, layer: CustomSubModule):
self.layers.append(layer)
However, I can’t load the state_dict using the load_state_dict function
custom_sub_module = CustomSubModule(10, 10, T.rand(10, 10), T.rand(10, ),
T.tensor([5]), T.tensor([4]))
costum_module = CustomModule()
costum_module.add_layer(custom_sub_module)
T.save(costum_module.state_dict(), '/tmp/test_model.pt')
loaded_module = CustomModule()
loaded_module.load_state_dict(T.load('/tmp/test_model.pt'))
Traceback (most recent call last):
File "tests/save_dict_example.py", line 36, in <module>
loaded_module.load_state_dict(T.load('/tmp/test_model.pt'))
File ".venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CustomModule:
Unexpected key(s) in state_dict: "layers.0.bias", "layers.0.alpha", "layers.0.beta", "layers.0.no_bias_layer.weight".
Is there any way to overcome this?