assert scalar.squeeze().ndim == 0, "scalar should be 0D" fails with AssertionError: scalar should be 0D

Hi, I’m training a GNN model (encoder) in PyTorch Geometric with a custom loss function, and I’m facing an issue during the training iteration of my model. The code is below:

```python
def train(dataset, epochs, criterion, writer):
    train_loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn, shuffle=False)

    hidden_features = 16
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # build the model
    model = GNN(dataset.num_features, hidden_features=hidden_features)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.0001)

    # train
    for epoch in range(1, epochs + 1):
        total_loss = 0
        for batch in train_loader:
            for data in batch:

                # Zero your gradients for every batch!
                opt.zero_grad()

                output = model(data.x, data.edge_index)

                loss = criterion.forward(output, data.y)
                loss.backward()
                opt.step()

                total_loss += loss
                writer.add_scalar("loss", total_loss, epoch)

            print("Epoch {}. Loss: {:.4f}".format(
                epoch, loss))

    return model
```

Can you please tell me how to fix this, and whether the training process is correct?

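I guess the error is raised in writer.add_scalar("loss", total_loss, epoch), which expects a scalar input, while total_loss seems to contain multiple values. This would happen if your custom criterion returns an unreduced loss (e.g. one value per node) instead of a single value, so reduce it to a scalar (e.g. via loss.mean() or loss.sum()) before calling backward() on it and logging it.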
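A minimal sketch of how total_loss can end up with multiple values, assuming the custom criterion does not reduce its output. nn.MSELoss(reduction="none") is only a stand-in for that custom loss, and the tensor shapes are made up for illustration. The per-element loss is not 0D even after squeeze(), which is exactly what the assertion checks, while the reduced loss is.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a custom criterion that does not reduce its output:
# it returns one loss value per node instead of a single scalar.
criterion = nn.MSELoss(reduction="none")

output = torch.randn(5, 1, requires_grad=True)  # e.g. 5 nodes, 1 output feature
y = torch.randn(5, 1)

per_node_loss = criterion(output, y)
print(per_node_loss.shape)            # torch.Size([5, 1])
print(per_node_loss.squeeze().ndim)   # 1, so the "scalar should be 0D" check fails

# Reduce to a 0D tensor before backward() and logging.
loss = per_node_loss.mean()
print(loss.ndim)                      # 0

# .item() returns a plain Python float, which add_scalar accepts directly
# and which no longer references the computation graph.
value = loss.item()
print(type(value))                    # <class 'float'>
```

Whether mean() or sum() is the right reduction depends on how the custom loss is meant to be scaled; mean() keeps the loss magnitude independent of the graph size.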

Thanks! Can you please tell me if the training process is correct?

The code looks generally alright, apart from a few minor issues:

  • During training you would usually shuffle the dataset, while you are specifying shuffle=False in the DataLoader.
  • .cuda() is not an in-place operation on tensors, so you need to reassign the result, i.e. data = data.cuda() or data = data.to(device), in case data is a tensor.
  • You should not call the forward method directly but the object itself, so change criterion.forward(...) to criterion(...). Calling the module goes through nn.Module.__call__, which also runs any registered hooks that a direct forward() call skips.
  • You are storing the entire computation graph in total_loss += loss, since loss is still attached to it. Use total_loss += loss.detach() (or loss.item() to get a Python float) to avoid increasing the memory usage in each iteration.
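The second to fourth points can be checked with a few lines on the CPU. This sketch uses .to(torch.float64) as a CPU-friendly stand-in for .cuda(), since both return a new tensor instead of modifying the original, and a forward hook on nn.MSELoss (a stand-in for the custom criterion) to show which call style runs hooks.

```python
import torch
import torch.nn as nn

# 1) Tensor conversions are not in-place: the result must be reassigned.
t = torch.zeros(3)
t.to(torch.float64)              # result discarded, t is unchanged
print(t.dtype)                   # torch.float32
t = t.to(torch.float64)          # reassign, as with data = data.cuda()
print(t.dtype)                   # torch.float64

# 2) Calling the module runs registered hooks; calling .forward() directly does not.
criterion = nn.MSELoss()
calls = []
criterion.register_forward_hook(lambda module, inputs, output: calls.append(1))
a, b = torch.randn(4), torch.randn(4)
criterion(a, b)                  # goes through nn.Module.__call__, hook fires
criterion.forward(a, b)          # bypasses __call__, hook does not fire
print(len(calls))                # 1

# 3) loss.detach() drops the reference to the computation graph.
w = torch.randn(4, requires_grad=True)
loss = (w - b).pow(2).mean()     # still attached to the graph through w
total_loss = 0.0
total_loss += loss.detach()      # accumulates a tensor without grad history
print(loss.requires_grad)        # True
print(total_loss.requires_grad)  # False
```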
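Putting these points together, a possible version of the training loop could look like the sketch below. It is only runnable on its own because it swaps in hypothetical stand-ins for the parts not shown in the thread: ToyGNN (two nn.Linear layers that ignore edge_index), a list of SimpleNamespace graphs with x, edge_index and y attributes, a collate_fn that returns the batch as a list, torch.utils.data.DataLoader instead of whichever loader the original code uses, and a ListWriter that only records add_scalar calls (torch.utils.tensorboard.SummaryWriter would be used in practice). The loss is reduced with .mean() in case the criterion does not reduce it, and the mean loss is logged once per epoch.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class ToyGNN(nn.Module):
    """Hypothetical stand-in for the GNN in the question (ignores edge_index)."""

    def __init__(self, num_features, hidden_features):
        super().__init__()
        self.lin1 = nn.Linear(num_features, hidden_features)
        self.lin2 = nn.Linear(hidden_features, 1)

    def forward(self, x, edge_index):
        return self.lin2(torch.relu(self.lin1(x)))


def collate_fn(graphs):
    # Hypothetical: keep the graphs of a batch as a list so "for data in batch" works.
    return graphs


def train(dataset, num_features, epochs, criterion, writer):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # shuffle=True so each epoch visits the graphs in a different order
    train_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=True)

    # nn.Module.to() moves the parameters in place and returns the module
    model = ToyGNN(num_features, hidden_features=16).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        num_graphs = 0
        for batch in train_loader:
            for data in batch:
                # Tensor.to() returns new tensors, so the results are reassigned
                x = data.x.to(device)
                edge_index = data.edge_index.to(device)
                y = data.y.to(device)

                opt.zero_grad()                     # zero the gradients for every graph
                output = model(x, edge_index)
                loss = criterion(output, y).mean()  # call the module, reduce to a 0D tensor
                loss.backward()
                opt.step()

                total_loss += loss.item()           # Python float, no graph kept alive
                num_graphs += 1

        epoch_loss = total_loss / num_graphs
        writer.add_scalar("loss", epoch_loss, epoch)  # one scalar per epoch
        print("Epoch {}. Loss: {:.4f}".format(epoch, epoch_loss))

    return model


class ListWriter:
    """Hypothetical stand-in for SummaryWriter that records add_scalar calls."""

    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))


# Usage with random toy graphs and a deliberately unreduced criterion.
torch.manual_seed(0)
dataset = [
    SimpleNamespace(
        x=torch.randn(4, 3),
        edge_index=torch.zeros(2, 0, dtype=torch.long),
        y=torch.randn(4, 1),
    )
    for _ in range(8)
]
writer = ListWriter()
model = train(dataset, num_features=3, epochs=2,
              criterion=nn.MSELoss(reduction="none"), writer=writer)
```

Logging the per-epoch mean instead of the running sum inside the inner loop keeps one point per epoch in TensorBoard and makes epochs with different numbers of graphs comparable.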