Weights are not updating during backpropagation

Hi everyone, I have an issue when running this code:

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dgl.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feats: int, h_feats: int, num_classes: int, aggregator: list = ['mean', 'mean']):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator[0])
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator[1])

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        lg_softmax = torch.nn.LogSoftmax(dim=1)
        return h, lg_softmax(h).argmax(1).type(torch.float32).requires_grad_()


def train(g, model, epochs: int = 200, lr: float = 0.05, display: bool = True):

    training_loss = []
    train_accuracy = []
    val_accuracy = []
    test_accuracy = []
    all_logits = []

    best_val_acc = 0
    best_test_acc = 0

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    features = g.ndata["feat"]
    labels_data = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    for epoch in range(epochs):

        print("\n", list(model.parameters()))

        # Forward pass
        logits, pred = model(g, features)
        print(pred)
        loss = criterion(pred, labels_data)
        training_loss.append(loss.item())
        all_logits.append(torch.squeeze(logits).clone().detach())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_accuracy.append(torch.sum(pred[train_mask] == labels_data[train_mask]).item() / labels_data[train_mask].size(0))
        val_accuracy.append(torch.sum(pred[val_mask] == labels_data[val_mask]).item() / labels_data[val_mask].size(0))
        test_accuracy.append(torch.sum(pred[test_mask] == labels_data[test_mask]).item() / labels_data[test_mask].size(0))

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_accuracy[epoch]:
            best_val_acc = val_accuracy[epoch]
            best_test_acc = test_accuracy[epoch]

        if epoch % 50 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    epoch, loss, val_accuracy[epoch], best_val_acc, test_accuracy[epoch], best_test_acc
                )
            )
            # print(pred, labels_data)

    if display:
        plt.figure(figsize=(14, 6))
        plt.subplot(1, 2, 1)
        plt.plot(range(1, len(training_loss) + 1), training_loss)
        plt.xlabel("Epochs")
        plt.ylabel("Loss")

        plt.subplot(1, 2, 2)
        plt.plot(range(1, len(train_accuracy) + 1), train_accuracy, label="Training accuracy")
        plt.plot(range(1, len(val_accuracy) + 1), val_accuracy, label="Validation accuracy")
        plt.plot(range(1, len(test_accuracy) + 1), test_accuracy, label="Test accuracy")
        plt.legend()
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")

        plt.tight_layout()
        plt.show()

    return all_logits
```

The weights are not updating, and I think I know why: my guess is that `lg_softmax(h).argmax(1).type(torch.float32).requires_grad_()` detaches the result from the computation graph (see the minimal check below).
But I don't know how to create a tensor with `requires_grad = True` without detaching the other variables from the graph.
I hope you can help me.
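
Here is a minimal check of what I mean (plain PyTorch, no DGL, with a random tensor standing in for the logits `h`):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)              # stands in for the logits h
y = x.argmax(1).type(torch.float32).requires_grad_()   # same chain as in my forward()

print(y.grad_fn)   # None -> y is a brand-new leaf, not connected to x
y.sum().backward()
print(x.grad)      # None -> no gradient ever reaches x (or the model weights)
```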

Thank you!

That is correct: `argmax` is not a differentiable operation, so the gradient chain is cut there. Calling `.requires_grad_()` on the result won't re-attach it to the computation graph; it only turns the already detached tensor into a new leaf.
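
The usual fix is to compute the loss directly on the raw logits and use `argmax` only for gradient-free accuracy bookkeeping. Here is a minimal sketch of one training step, assuming your labels are integer class indices (as in the standard DGL node-classification tutorials) and that `forward` is changed to return only `h`:

```python
import torch
import torch.nn.functional as F

# inside the training loop
logits = model(g, features)                       # forward() now returns just h
loss = F.cross_entropy(logits[train_mask], labels_data[train_mask])

optimizer.zero_grad()
loss.backward()
optimizer.step()

with torch.no_grad():                             # argmax is fine here, no gradients needed
    pred = logits.argmax(dim=1)
    train_acc = (pred[train_mask] == labels_data[train_mask]).float().mean().item()
```

If your labels really are multi-hot (so `BCEWithLogitsLoss` is the right criterion), the same idea applies: pass the raw logits to the loss, never the `argmax` output.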