If a module contains a dictionary that holds two other modules, as shown below, can I get the parameters of `model_dict['model1']` and `model_dict['model2']` with `outer_network.state_dict()`? Or is there another way to get the parameters of modules stored in a dictionary?
class inner_network_1(nn.Module):
    """First inner sub-network: a single linear layer mapping 20 -> 30 features.

    Its tensors appear in ``state_dict()`` as ``inner_fc.weight`` and
    ``inner_fc.bias``.
    """

    def __init__(self):
        super().__init__()
        # Assigning an nn.Module to an attribute registers it as a submodule.
        self.inner_fc = nn.Linear(20, 30)
class inner_network_2(nn.Module):
    """Second inner sub-network: a single linear layer mapping 20 -> 30 features.

    Its tensors appear in ``state_dict()`` as ``inner_fc.weight`` and
    ``inner_fc.bias``.
    """

    def __init__(self):
        super().__init__()
        # Attribute assignment of an nn.Module registers it as a child module.
        self.inner_fc = nn.Linear(20, 30)
class outer_network(nn.Module):
    """Outer network with a 10 -> 20 linear layer plus two inner networks
    kept in an ordinary Python ``dict``.

    NOTE: ``model_dict`` is a plain ``dict`` bound only to a local variable of
    ``__init__``. The inner networks are therefore never registered as
    submodules. ``state_dict()`` contains only ``fc.weight`` and ``fc.bias``,
    and nothing keeps the inner networks alive once ``__init__`` returns.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 20)
        # Plain dict: PyTorch does not look inside it, so nothing is registered.
        model_dict = {
            'model1': inner_network_1(),
            'model2': inner_network_2(),
        }
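A plain Python `dict` does not register the modules it holds with the parent module, so their parameters never show up in `state_dict()`. `nn.ModuleDict` does register them. Do you mean something like this: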
class outer_network(nn.Module):
    """Outer network whose inner layers live in ``nn.ModuleDict`` attributes.

    Each ``nn.ModuleDict`` is assigned to an attribute, so it and every
    module it holds are registered as submodules. ``state_dict()`` therefore
    lists ``fc.*``, ``inner_network_1.inner_fc.*`` and
    ``inner_network_2.inner_fc.*``.
    """

    def __init__(self):
        super(outer_network, self).__init__()
        self.fc = nn.Linear(10, 20)
        # A mapping behaves the same as a list of (key, module) pairs here.
        self.inner_network_1 = nn.ModuleDict({'inner_fc': nn.Linear(20, 30)})
        self.inner_network_2 = nn.ModuleDict({'inner_fc': nn.Linear(20, 30)})


model = outer_network()
# Keys: fc.*, inner_network_1.inner_fc.*, inner_network_2.inner_fc.*
model.state_dict()
Or something like this:
class inner_network_1(nn.Module):
    """First inner sub-network, registered under the key ``'model1'``.

    Holds one linear layer (20 -> 30 features) exposed as ``inner_fc``.
    """

    def __init__(self):
        super().__init__()
        # Registered automatically because the value is an nn.Module.
        self.inner_fc = nn.Linear(20, 30)
class inner_network_2(nn.Module):
    """Second inner sub-network, registered under the key ``'model2'``.

    Holds one linear layer (20 -> 30 features) exposed as ``inner_fc``.
    """

    def __init__(self):
        super().__init__()
        # Registered automatically because the value is an nn.Module.
        self.inner_fc = nn.Linear(20, 30)
class outer_network(nn.ModuleDict):
    """Outer network that is itself an ``nn.ModuleDict``.

    ``fc`` is registered through attribute assignment. ``model1`` and
    ``model2`` are registered through item assignment. ``state_dict()``
    therefore lists ``fc.*``, ``model1.inner_fc.*`` and ``model2.inner_fc.*``,
    in that order.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 20)
        # Build and register each inner network in turn, model1 before model2.
        for key, network_cls in (('model1', inner_network_1),
                                 ('model2', inner_network_2)):
            self[key] = network_cls()


model = outer_network()
# Keys: fc.*, model1.inner_fc.*, model2.inner_fc.*
model.state_dict()
Thank you for these two examples! I didn't know about `nn.ModuleDict`, which is exactly what I need!