Hello everyone, I need your help,

I created a Graph Neural Network using Pytorch Geometric and trained it in Google Colab since I don’t have GPUs available. Below you can see an example of my network:

class GIN1(torch.nn.Module):

“”" 1 Vs. 2 GIN “”"

```
def __init__(self, h):
super(GIN1, self).__init__()
dim_h_conv = h
dim_h_fc = dim_h_conv * 5
# Convolutional layers
self.conv1 = GINConv(Sequential(Linear(14, dim_h_conv),
BatchNorm1d(dim_h_conv), ReLU(),
Linear(dim_h_conv, dim_h_conv), ReLU()))
self.conv2 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
BatchNorm1d(dim_h_conv), ReLU(),
Linear(dim_h_conv, dim_h_conv), ReLU()))
self.conv3 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
BatchNorm1d(dim_h_conv), ReLU(),
Linear(dim_h_conv, dim_h_conv), ReLU()))
self.conv4 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
BatchNorm1d(dim_h_conv), ReLU(),
Linear(dim_h_conv, dim_h_conv), ReLU()))
self.conv5 = GINConv(Sequential(Linear(dim_h_conv, dim_h_conv),
BatchNorm1d(dim_h_conv), ReLU(),
Linear(dim_h_conv, dim_h_conv), ReLU()))
# Fully connected layers
self.lin1 = Linear(dim_h_fc, dim_h_fc)
self.lin2 = Linear(dim_h_fc, 4)
def forward(self, x, edge_index, batch):
h1 = self.conv1(x, edge_index)
h2 = self.conv2(h1, edge_index)
h3 = self.conv3(h2, edge_index)
h4 = self.conv4(h3, edge_index)
h5 = self.conv5(h4, edge_index)
h5 = F.dropout(h5, p=0.2, training=self.training)
# Graph level readout
h1 = global_add_pool(h1, batch)
h2 = global_add_pool(h2, batch)
h3 = global_add_pool(h3, batch)
h4 = global_add_pool(h4, batch)
h5 = global_add_pool(h5, batch)
# Concatenate graph embeddings
h = torch.cat((h1, h2, h3, h4, h5), dim=1)
# Classifier
h = self.lin1(h)
h = h.relu()
h = F.dropout(h, p=0.3, training=self.training)
h = self.lin2(h)
h = F.log_softmax(h, dim=1)
return h
```

After doing this I went to save the whole network with torch.save() function as a .pt file.

model = GIN1(h = 20)

torch.save(model, “gin1.pt”)

models = torch.load(‘gin1.pt’)

However, when I go to import the network into PyCharm I get the following error:

Can’t get attribute ‘GIN1’ on <module ‘

main’

Do you know how to solve it? Thank you all for the help!