How to make the GCN output binary (0 or 1)

Hello. I am very new to graph neural networks in PyTorch. I have multiple graphs and want to do node classification. I split the graphs into 80% for training, 10% for validation and 10% for testing, and tried to use a GCN to classify the nodes. To do that, I implemented the following code:
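One common way to do an 80/10/10 split over whole graphs is to shuffle the list of graphs once and slice it. This is a minimal sketch, not the poster's code. `split_graphs` is a hypothetical helper, not a PyTorch or PyTorch Geometric function, and the sketch assumes the graphs are held in a plain Python list (for example, a list of `torch_geometric.data.Data` objects).

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_graphs(graphs: Sequence[T], train_frac: float = 0.8, val_frac: float = 0.1,
                 seed: int = 0) -> Tuple[List[T], List[T], List[T]]:
    """Hypothetical helper: shuffle graphs and split them into train/val/test lists.

    Whatever is left after train_frac and val_frac goes to the test split.
    """
    items = list(graphs)
    random.Random(seed).shuffle(items)  # fixed seed keeps the split reproducible
    n_train = int(len(items) * train_frac)
    n_val = int(len(items) * val_frac)
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    return train, val, test


train, val, test = split_graphs(list(range(10)))
print(len(train), len(val), len(test))  # → 8 1 1
```

Each resulting list could then be wrapped in a loader such as `torch_geometric.loader.DataLoader(train, batch_size=..., shuffle=True)`. The batch size is left open here because the post does not give one. Splitting at the graph level, rather than the node level, keeps every node of a given graph in the same split.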

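```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()  # required before registering submodules
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        # Define the linear layer
        self.linear = torch.nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_index, edge_weight):
        # x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        # p defaults to 0.5; training=self.training disables dropout in eval mode
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        # Apply the ReLU activation function
        x = F.relu(x)
        # Apply the linear layer
        x = self.linear(x)
        # One log-probability per class for every node: shape [num_nodes, out_channels]
        return F.log_softmax(x, dim=1)
```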
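Because `forward` ends with `F.log_softmax(x, dim=1)`, the model produces one log-probability per class for each node, not a single 0/1 value. The sketch below uses random numbers as a stand-in for the output of `self.linear(x)`, so no graph is needed. It shows the resulting shape and checks that exponentiating each row gives probabilities that sum to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_classes = 5, 2

# Stand-in for the raw scores from self.linear(x): one row per node
logits = torch.randn(num_nodes, num_classes)
# The same operation the model's forward() applies last
log_probs = F.log_softmax(logits, dim=1)

print(log_probs.shape)   # → torch.Size([5, 2])
probs = log_probs.exp()  # convert log-probabilities back to probabilities
print(probs.sum(dim=1))  # each entry is 1 up to float rounding
```

So with `max_num_classes = 2`, `output` is a `[num_nodes, 2]` tensor of class scores.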

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

in_channels = gInfo.getNodeNumberOfFeatures()  # Number of input features per node
hidden_channels = gInfo.getNodeNumberOfFeatures() + 4  # Number of output features for the first convolutional layer
max_num_classes = 2  # Since the target output can be either 0 or 1

# Initialize the GCN model with the maximum number of classes
model = GCN(in_channels, hidden_channels, max_num_classes).to(device)

# Use Binary Cross Entropy Loss for scenarios where target has only one class
binary_loss_op = torch.nn.BCEWithLogitsLoss()

criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
```
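The two loss classes defined above expect different output and target shapes. As a hedged sketch with made-up tensors rather than the poster's data, here are two standard PyTorch pairings. `torch.nn.NLLLoss` takes `[N, C]` log-probabilities, which is what `log_softmax` returns, together with one integer class index per node. `torch.nn.BCEWithLogitsLoss` takes raw logits (no `log_softmax`) with the same shape as a float target, which would correspond to a hypothetical variant of the model with a single output unit.

```python
import torch

torch.manual_seed(0)
num_nodes = 4
y = torch.tensor([0, 1, 1, 0])  # one 0/1 label per node

# Pairing A: two outputs per node + log_softmax, used with NLLLoss
log_probs = torch.log_softmax(torch.randn(num_nodes, 2), dim=1)  # shape [4, 2]
nll = torch.nn.NLLLoss()
loss_a = nll(log_probs, y.long())  # target: class indices, shape [4], dtype long

# Pairing B: one raw logit per node (no log_softmax), used with BCEWithLogitsLoss
logits = torch.randn(num_nodes)  # shape [4]
bce = torch.nn.BCEWithLogitsLoss()
loss_b = bce(logits, y.float())  # target: same shape as logits, dtype float

print(loss_a.item(), loss_b.item())
```

In a single-output model, `self.linear` would return shape `[N, 1]`. The target would then also need shape `[N, 1]`, or the output would need `.squeeze(-1)`, so that both tensors match.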

```python
def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.edge_weight)
        # Calculate loss
        loss = binary_loss_op(output, data.y.squeeze(dim=1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)
```

In the loss computation, `output` has shape `[num_nodes, 2]` (two values per node), while `data.y` contains only 0 or 1. How can I turn the output into 0 or 1 so that it can be compared with `data.y`?
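A `[num_nodes, 2]` output is usually turned into a hard 0/1 label per node with `argmax` over the class dimension. The loss itself is still computed on the log-probabilities, because `argmax` is not differentiable. In the minimal sketch below, hand-picked tensors stand in for `output` and `data.y`. `data.y` is assumed to have shape `[num_nodes, 1]`, as the `squeeze(dim=1)` in the question suggests.

```python
import torch

# Stand-in for the model output: log-probabilities for classes 0 and 1
output = torch.log(torch.tensor([[0.9, 0.1],
                                 [0.2, 0.8],
                                 [0.6, 0.4]]))
# Stand-in for data.y with shape [num_nodes, 1]
y = torch.tensor([[0], [1], [1]])

pred = output.argmax(dim=1)  # hard 0/1 prediction per node, shape [3]
target = y.squeeze(dim=1)    # shape [3], now directly comparable with pred
correct = (pred == target).sum().item()
accuracy = correct / target.numel()
print(pred.tolist(), accuracy)  # → [0, 1, 0] 0.6666666666666666
```

For evaluation on the validation or test loaders, this would typically run after `model.eval()` and inside `torch.no_grad()`.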