Hi evreyone, I have an issue when running this coe
`class GraphSAGE(torch.nn.Module):
def init(self, in_feats: int, h_feats: int, num_classes: int, aggregator:list = [‘mean’, ‘mean’]):
super(GraphSAGE, self).init()
self.conv1 = SAGEConv(in_feats, h_feats, aggregator[0])
self.conv2 = SAGEConv(h_feats, num_classes, aggregator[1])
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
lg_softmax = torch.nn.LogSoftmax(dim = 1)
return h, lg_softmax(h).argmax(1).type(torch.float32).requires_grad_()
def train(g, model, epochs: int = 200, lr : float = 0.05, display : bool = True):
training_loss = []
train_accuracy = []
val_accuracy = []
test_accuracy = []
all_logits = []
best_val_acc = 0
best_test_acc = 0
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
features = g.ndata["feat"]
labels_data = g.ndata["label"]
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
for epoch in range(epochs):
print("\n", list(model.parameters()))
# Forward pass
logits, pred = model(g, features)
print(pred)
loss = criterion(pred, labels_data)
training_loss.append(loss.item())
all_logits.append(torch.squeeze(logits).clone().detach())
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_accuracy.append(torch.sum(pred[train_mask] == labels_data[train_mask]).item() / labels_data[train_mask].size(0))
val_accuracy.append(torch.sum(pred[val_mask] == labels_data[val_mask]).item() / labels_data[val_mask].size(0))
test_accuracy.append(torch.sum(pred[test_mask] == labels_data[test_mask]).item() / labels_data[test_mask].size(0))
# Save the best validation accuracy and the corresponding test accuracy.
if best_val_acc < val_accuracy[epoch]:
best_val_acc = val_accuracy[epoch]
best_test_acc = test_accuracy[epoch]
if epoch % 50 == 0:
print(
"In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
epoch, loss, val_accuracy[epoch], best_val_acc, test_accuracy[epoch], best_test_acc
)
)
# print(pred, labels_data)
if display :
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(training_loss)+1), training_loss)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracy)+1), train_accuracy, label="Training accuracy")
plt.plot(range(1, len(val_accuracy)+1), val_accuracy, label="Validation accuracy")
plt.plot(range(1, len(test_accuracy)+1), test_accuracy, label="Test accuracy")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.tight_layout()
plt.show()
return all_logits`
the weights are not updating and I think I know why. My guess is that when I use lg_softmax(h).argmax(1).type(torch.float32).requires_grad_()
I detach the variables from the graph…
But I don’t know how to create a variable with a requires_grad = True without detaching the other variables.
I hope you can help me.
Thank you !