Hi, I am wondering what the correct way is to save and restore a model that is composed of two sub-models.

These are my models:

```
class TunedResNet(nn.Module):
    """ResNet-50 whose final fully connected layer is replaced by a new one.

    The replacement ``fc`` layer maps the backbone's pooled features to
    ``num_classes`` outputs (the module-level ``NUM_CLASSES`` by default).

    Args:
        num_classes: Output size of the new ``fc`` layer. ``None`` (the
            default) uses the module-level ``NUM_CLASSES`` constant.
        pretrained: When ``True`` (the default), initialise the backbone from
            ``ResNet50_Weights.IMAGENET1K_V2``. Pass ``False`` when the model
            is about to be restored from a full checkpoint with
            ``load_state_dict``: the checkpoint overwrites every backbone
            weight anyway, so loading the ImageNet weights first is wasted work.
    """

    def __init__(self, num_classes=None, pretrained=True):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.resnet = resnet50(weights=weights)
        # Replace the ImageNet classification head with one sized for this task.
        # The attribute name ``resnet`` is kept so state_dict keys
        # (``resnet.fc.weight``, ...) match previously saved checkpoints.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Run *x* through the backbone and return the new ``fc`` layer's output."""
        return self.resnet(x)
class GeoNet(nn.Module):
    """Small MLP branch; CombinedNet feeds it ``x["location"]``.

    Architecture: ``Linear(in, hidden) -> ReLU -> Dropout -> Linear(hidden, out) -> ReLU``.
    The sub-layers are named (``fc1``, ``relu1``, ``drop``, ``fc2``, ``relu2``)
    so the state_dict keys are ``model.fc1.weight`` etc. and callers can read
    ``model.fc2.out_features``.

    Args:
        in_features: Size of each input vector (default 4).
        hidden_features: Width of the hidden layer (default 500).
        out_features: Size of the output vector (default 1000).
        dropout: Dropout probability applied after the first ReLU (default 0.2).
    """

    def __init__(self, in_features=4, hidden_features=500, out_features=1000, dropout=0.2):
        super().__init__()
        self.model = nn.Sequential(
            OrderedDict([
                ("fc1", nn.Linear(in_features=in_features, out_features=hidden_features)),
                ("relu1", nn.ReLU()),
                ("drop", nn.Dropout(p=dropout)),
                ("fc2", nn.Linear(in_features=hidden_features, out_features=out_features)),
                ("relu2", nn.ReLU()),
            ])
        )

    def forward(self, x):
        """Apply the MLP to *x*; the final ReLU makes every output non-negative."""
        return self.model(x)
class CombinedNet(nn.Module):
    """Fuse an image branch and a location branch with a linear head.

    Both sub-networks are assigned as attributes (``self.resnet`` and
    ``self.geonet``), so PyTorch registers them as child modules and
    ``CombinedNet.state_dict()`` already contains all of their parameters and
    buffers under the ``resnet.`` / ``geonet.`` prefixes. Saving this single
    state_dict is therefore enough. To restore, build the same architecture
    again and call ``load_state_dict`` on it.

    Args:
        resnet: Image branch; must expose ``resnet.resnet.fc.out_features``.
        geonet: Location branch; must expose ``geonet.model.fc2.out_features``.
        num_classes: Output size of the fusion head. ``None`` (the default)
            uses the module-level ``NUM_CLASSES`` constant.
    """

    def __init__(self, resnet, geonet, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.resnet = resnet
        self.geonet = geonet
        self.resnet_out_feats = self.resnet.resnet.fc.out_features
        self.geonet_out_feats = self.geonet.model.fc2.out_features
        # Kept as a one-element Sequential so the state_dict keys stay
        # ``fc.0.weight`` / ``fc.0.bias`` and existing checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=self.resnet_out_feats + self.geonet_out_feats,
                out_features=num_classes,
            )
        )

    def forward(self, x):
        """Return the fusion head's output for a batch.

        Args:
            x: Mapping with an ``"image"`` entry (fed to the image branch) and a
                ``"location"`` entry (fed to the location branch). Exact tensor
                shapes depend on the branches; presumably both share the batch
                dimension — confirm with the data loader.
        """
        r = self.resnet(x["image"])
        g = self.geonet(x["location"])
        # Flatten each branch output to (rows, features) before concatenating
        # along the feature axis.
        concat_tensor = torch.cat(
            (r.view(-1, self.resnet_out_feats), g.view(-1, self.geonet_out_feats)),
            dim=1,
        )
        return self.fc(concat_tensor)
```

`CombinedNet` is composed of two models: `TunedResNet` and `GeoNet`.

I read this tutorial, and it saves each model's `state_dict` separately. However, when I look at the `state_dict` of the combined model, it seems to contain all the parameters, so I think I don't need to save each model's state dict separately. Can someone confirm my understanding? Thanks!

This is what I see when I print the keys of the combined model's `state_dict`. Note that it lists buffers such as BatchNorm `running_mean` as well as parameters:

```
# Iterating a state_dict yields its keys directly; no .keys() call is needed.
# The keys include buffers (e.g. BatchNorm running stats), not only parameters.
for k in combined_model.state_dict():
    print(k)
```

```
<SKIPPED> earlier layers
resnet.resnet.layer4.2.conv3.weight
resnet.resnet.layer4.2.bn3.weight
resnet.resnet.layer4.2.bn3.bias
resnet.resnet.layer4.2.bn3.running_mean
resnet.resnet.layer4.2.bn3.running_var
resnet.resnet.layer4.2.bn3.num_batches_tracked
resnet.resnet.fc.weight
resnet.resnet.fc.bias
geonet.model.fc1.weight
geonet.model.fc1.bias
geonet.model.fc2.weight
geonet.model.fc2.bias
fc.0.weight
fc.0.bias
```