How do I check and read a confusion matrix?

If you just want to print it at the end, paste the print statements at the desired place (e.g., after the training loop).

So, it will be like this, right?

# Per-epoch loop: one optimisation pass over dataloader_train, then an
# evaluation pass over dataloader_test, then a printed summary of the epoch.
# model, criterion, optimizer, device, epochs, the two dataloaders,
# confusion_matrix and conf_matrix are all defined outside this snippet.
for epoch in range(epochs):
  
  # Sum of per-batch training losses; averaged over batches when printed.
  running_loss = 0
  model.train()
  for images, labels in dataloader_train:
    
    #steps += 1
    images, labels = images.to(device), labels.to(device)
    
    # Clear gradients left over from the previous batch.
    optimizer.zero_grad()
    
    output = model.forward(images)
    # Updates conf_matrix from this *training* batch. The helper is defined
    # elsewhere; presumably it adds this batch's counts to the matrix passed
    # in and returns it -- confirm.
    # NOTE(review): conf_matrix is never reset in the visible code, so it
    # appears to keep accumulating across batches and across epochs.
    conf_matrix = confusion_matrix(output, labels, conf_matrix)
    # p and prediction are not read again anywhere in the visible code.
    p = torch.nn.functional.softmax(output, dim=1)
    prediction = torch.argmax(p, dim=1)
    #loss = torch.nn.functional.nll_loss(torch.log(p), y)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
  #if steps % print_every == 0:
  # Evaluation pass: accumulators are reset every epoch.
  valid_loss = 0
  accuracy = 0
  model.eval()
  #print(conf_matrix)(**NOT HERE RIGHT?**)
  for images, labels in dataloader_test:
    # No backward() is called in this loop (it runs under torch.no_grad()),
    # so this zero_grad() has no visible effect here.
    optimizer.zero_grad()
    with torch.no_grad():
       
      images, labels = images.to(device), labels.to(device)

      output = model.forward(images)
      # As in the training loop, p and prediction are not used afterwards.
      p = torch.nn.functional.softmax(output, dim=1)
      prediction = torch.argmax(p, dim=1)
      loss = criterion(output, labels)
          
      valid_loss += loss.item()
          
      # NOTE(review): exp() only yields probabilities if output holds
      # log-probabilities -- confirm against the model's last layer. The
      # top-1 class below is the same either way, because exp is monotonic.
      ps = torch.exp(output)
         
      top_p, top_class = ps.topk(1, dim = 1)
      # Compare the predicted class with the label, reshaped to match.
      equals = top_class == labels.view(*top_class.shape)
      # Adds this batch's mean accuracy; divided by the batch count below.
      accuracy += torch.mean(equals.type(torch.FloatTensor))
  
  # conf_matrix is only updated inside the training loop above, so what is
  # printed here reflects training batches, not the dataloader_test pass.
  print(conf_matrix)      
  print("Epoch: {}/{} " .format(epoch+1, epochs))
  # Each metric is averaged over the number of batches in its dataloader.
  print("Train loss: {:.4f}.. " .format(running_loss/len(dataloader_train)))
  print("Valid loss: {:.4f}.. " .format(valid_loss/len(dataloader_test)))
  print("Accuracy: {:.4f}.. " .format(accuracy/len(dataloader_test)))
  # Switch back to training mode; model.train() is also called at the top
  # of every epoch.
  model.train()

Also, it looks like you are computing the training and validation stats together, which is a bad idea. Usually you would re-create the confusion matrix for each dataset.

So, do I have to make two different functions for the confusion matrix? If not, what should I do? I have disabled all the print statements inside the Confusion_matrix function. Thanks.