Hello everyone, I need your help,
I created a Graph Neural Network using PyTorch Geometric and trained it in Google Colab, since I don’t have GPUs available. Below you can see an example of my network:
class GIN1(torch.nn.Module):
    """1 Vs. 2 GIN.

    Graph-level classifier built from five stacked ``GINConv`` layers. The
    node embeddings produced by every layer are sum-pooled per graph,
    concatenated, and passed through a two-layer classifier that emits
    log-softmax scores over 4 outputs.

    Args:
        h: Width of every hidden layer inside the GIN convolutions. The
            classifier's hidden width is ``5 * h`` (one pooled embedding per
            convolution).
    """

    def __init__(self, h):
        super().__init__()
        dim_h_conv = h
        # Five pooled embeddings of width ``h`` are concatenated in forward().
        dim_h_fc = dim_h_conv * 5

        # Convolutional layers. The attribute names and creation order are
        # kept as-is so state_dict keys and seeded initialisation do not change.
        self.conv1 = GINConv(self._make_mlp(14, dim_h_conv))
        self.conv2 = GINConv(self._make_mlp(dim_h_conv, dim_h_conv))
        self.conv3 = GINConv(self._make_mlp(dim_h_conv, dim_h_conv))
        self.conv4 = GINConv(self._make_mlp(dim_h_conv, dim_h_conv))
        self.conv5 = GINConv(self._make_mlp(dim_h_conv, dim_h_conv))

        # Fully connected layers
        self.lin1 = Linear(dim_h_fc, dim_h_fc)
        self.lin2 = Linear(dim_h_fc, 4)

    @staticmethod
    def _make_mlp(dim_in, dim_out):
        """Build the Linear -> BatchNorm -> ReLU -> Linear -> ReLU block used by each GINConv."""
        return Sequential(
            Linear(dim_in, dim_out),
            BatchNorm1d(dim_out),
            ReLU(),
            Linear(dim_out, dim_out),
            ReLU(),
        )

    def forward(self, x, edge_index, batch):
        """Compute per-graph log-softmax scores.

        Args:
            x: Node feature matrix; the first convolution expects 14 features
                per node.
            edge_index: Graph connectivity, forwarded unchanged to every
                ``GINConv`` layer.
            batch: Node-to-graph assignment, forwarded to ``global_add_pool``.

        Returns:
            Tensor of log-softmax scores with 4 columns, normalised along
            ``dim=1``.
        """
        # Node embeddings after each message-passing step
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)
        h4 = self.conv4(h3, edge_index)
        h5 = self.conv5(h4, edge_index)
        # Dropout is applied only to the last layer's node embeddings,
        # after they have been computed and before pooling.
        h5 = F.dropout(h5, p=0.2, training=self.training)

        # Graph level readout: sum-pool every layer's node embeddings
        pooled = [global_add_pool(h_k, batch) for h_k in (h1, h2, h3, h4, h5)]

        # Concatenate graph embeddings
        h = torch.cat(pooled, dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.3, training=self.training)
        h = self.lin2(h)
        return F.log_softmax(h, dim=1)
After this, I saved the whole network as a .pt file using the torch.save() function:
model = GIN1(h=20)
# NOTE: saving the module object itself pickles it by reference to its class,
# so any script that later calls torch.load() on this file must have GIN1
# defined or importable under the same module path. Saving
# model.state_dict() and loading it into a freshly constructed GIN1 avoids
# that coupling.
# NOTE(review): recent PyTorch releases restrict torch.load() to weights-only
# unpickling by default; confirm the behaviour of the installed version.
torch.save(model, "gin1.pt")
models = torch.load("gin1.pt")
However, when I try to load the network in PyCharm, I get the following error:
Can't get attribute 'GIN1' on <module '__main__'
Do you know how to solve it? Thank you all for the help!