Problems saving and loading a pre-trained network

Hello everyone, I need your help.
I created a Graph Neural Network using PyTorch Geometric and trained it in Google Colab, since I don't have GPUs available. Below is an example of my network:

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
from torch_geometric.nn import GINConv, global_add_pool


class GIN1(torch.nn.Module):
    """1 vs. 2 GIN"""

    def __init__(self, h):
        super(GIN1, self).__init__()
        dim_h_conv = h
        # The readout concatenates the pooled output of all 5 conv layers
        dim_h_fc = dim_h_conv * 5
        # Convolutional layers (14 input node features)
        self.conv1 = GINConv(Sequential(Linear(14, dim_h_conv),
                                        BatchNorm1d(dim_h_conv), ReLU(),
                                        Linear(dim_h_conv, dim_h_conv), ReLU()))
        self.conv2 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
                                        BatchNorm1d(dim_h_conv), ReLU(),
                                        Linear(dim_h_conv, dim_h_conv), ReLU()))
        self.conv3 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
                                        BatchNorm1d(dim_h_conv), ReLU(),
                                        Linear(dim_h_conv, dim_h_conv), ReLU()))
        self.conv4 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
                                        BatchNorm1d(dim_h_conv), ReLU(),
                                        Linear(dim_h_conv, dim_h_conv), ReLU()))
        self.conv5 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
                                        BatchNorm1d(dim_h_conv), ReLU(),
                                        Linear(dim_h_conv, dim_h_conv), ReLU()))
        # Fully connected layers (4 output classes)
        self.lin1 = Linear(dim_h_fc, dim_h_fc)
        self.lin2 = Linear(dim_h_fc, 4)

    def forward(self, x, edge_index, batch):
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)
        h4 = self.conv4(h3, edge_index)
        h5 = self.conv5(h4, edge_index)
        h5 = F.dropout(h5, p=0.2, training=self.training)
        # Graph-level readout: sum-pool node embeddings of every layer
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)
        h4 = global_add_pool(h4, batch)
        h5 = global_add_pool(h5, batch)
        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3, h4, h5), dim=1)
        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.3, training=self.training)
        h = self.lin2(h)
        h = F.log_softmax(h, dim=1)
        return h

After that, I saved the whole network as a .pt file with the torch.save() function:

model = GIN1(h=20)
torch.save(model, "gin1.pt")
models = torch.load("gin1.pt")

However, when I try to load the network in PyCharm, I get the following error:

Can't get attribute 'GIN1' on <module '__main__'

Do you know how to solve it? Thank you all for the help!

Saving the entire model can easily break, since torch.save(model) pickles only a reference to the model's class (its module and name, not its code), so the environment loading the model needs the same class available under the same module path.
Would it be possible to store model.state_dict() instead and later load it into a newly created model object?
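To see why the loading environment matters: torch.save() uses Python's pickle, which records a class only by its module and name. A model defined in a notebook is recorded as `__main__.GIN1`, and a PyCharm script's `__main__` has no `GIN1` unless you define or import it there. The sketch below reproduces this with the standard library only; `notebook_main` is a hypothetical module name standing in for the notebook's `__main__`, and the `GIN1` here is a placeholder, not the real network.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the Colab notebook's __main__.
notebook = types.ModuleType("notebook_main")
sys.modules["notebook_main"] = notebook


class GIN1:
    """Placeholder for the real network; pickle only records its name."""

    def __init__(self, h):
        self.h = h


# Make the class look as if it had been defined inside that module.
GIN1.__module__ = "notebook_main"
GIN1.__qualname__ = "GIN1"
notebook.GIN1 = GIN1

payload = pickle.dumps(GIN1(h=20))
# The payload holds the reference "notebook_main" + "GIN1", not the class code.
print(b"notebook_main" in payload and b"GIN1" in payload)

# Loading environment: same module name, but GIN1 is not defined there.
del notebook.GIN1
error_message = None
try:
    pickle.loads(payload)
except AttributeError as err:
    error_message = str(err)  # the same "Can't get attribute 'GIN1'" error
print(error_message)

# Defining (or importing) the class under the recorded name fixes loading.
notebook.GIN1 = GIN1
restored = pickle.loads(payload)
print(restored.h)  # 20
```

A state_dict sidesteps this, because it stores only tensors keyed by parameter name and you construct the model yourself before loading it.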


@ptrblck Thank you for your contribution! I solved the problem by redefining the GIN1 class in PyCharm and saving only the state_dict from Colab; now it works without any problems. Thank you very much! But what is the point of allowing the whole model to be saved if it is then so difficult to load?
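The state_dict workflow described here (define the class on both sides, transfer only the weights) can be sketched as follows. To keep it runnable without PyTorch Geometric, `TinyNet` is a hypothetical stand-in for GIN1 (it also contains a BatchNorm1d, so its running statistics travel with the state_dict), and an in-memory buffer replaces the .pt file. `map_location="cpu"` assumes the weights may have been saved from a GPU session, and `weights_only=True` requires a PyTorch version that supports that argument.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for GIN1: Linear + BatchNorm1d + Linear."""

    def __init__(self, h):
        super().__init__()
        self.lin = nn.Linear(14, h)
        self.bn = nn.BatchNorm1d(h)
        self.out = nn.Linear(h, 4)

    def forward(self, x):
        return self.out(torch.relu(self.bn(self.lin(x))))


# --- Training side (e.g. Colab) ---
torch.manual_seed(0)
trained = TinyNet(h=20)
trained(torch.randn(8, 14))  # a train-mode pass updates BatchNorm running stats

buffer = io.BytesIO()  # stands in for a file such as "gin1_state.pt"
torch.save(trained.state_dict(), buffer)  # tensors only, no class reference

# --- Loading side (e.g. PyCharm): the class must be defined or imported here ---
buffer.seek(0)
state = torch.load(buffer, map_location="cpu", weights_only=True)
restored = TinyNet(h=20)  # must use the same hyperparameters as in training
restored.load_state_dict(state)
restored.eval()  # inference mode: use running stats, disable dropout

trained.eval()
x = torch.randn(3, 14)
print(torch.allclose(trained(x), restored(x)))  # True
```

With the real network you would build `GIN1(h=20)` on the loading side instead of `TinyNet`. Note also that PyTorch 2.6 and later default torch.load to `weights_only=True`, which refuses a fully pickled model unless `weights_only=False` is passed explicitly, another reason to prefer the state_dict.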

Saving the whole model can work, but since I've seen it break many times, I would not recommend it, and I might not have allowed this workflow at all.
However, there are likely still valid use cases, and removing this method now would be a large backwards-compatibility break.