PyTorch training loop bug: loss is not updating

This is my training loop and my loss is not updating.

for epoch_i in range(5):
    total_loss = 0
    print(epoch_i, total_loss)

    model.train()

    for step, batch in enumerate(train_data, 0):
        train_x, train_y = tuple(t.to(device) for t in batch)

        model.zero_grad()

        logits = model(train_x)

        loss_fn = nn.CrossEntropyLoss()

        #print(train_y.shape, "y", train_x.shape, "x")

        loss = loss_fn(logits, train_y)

        total_loss = loss + total_loss

        loss.backward()

        optimizer.step()

This is my output:
0 0
1 0
2 0
3 0
4 0
The shapes of my tensors are:
train_x = torch.Size([32, 300])
train_y = torch.Size([32, 20])

Maybe I am missing something. Can someone spot why the loss is not updating?

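Move the print line to the end of the epoch loop, after the inner for loop has finished. As written, it runs right after total_loss is reset to 0, so it always prints 0 even though the loss is being accumulated. You should also change a couple of other lines to make the code bug-free: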

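  1. total_loss = loss.item() + total_loss, so that total_loss is a plain Python number instead of a tensor that accumulates autograd history across the whole loop.
  2. Move the loss_fn definition out of the loop body; there is no need to create a new nn.CrossEntropyLoss() on every step.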
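A minimal, self-contained sketch of the loop with the fixes above applied. The linear model, the random data, the SGD optimizer and its learning rate are hypothetical stand-ins, since the question does not show them. The labels are assumed to be integer class indices of shape [32]; if train_y really is one-hot with shape [32, 20], recent PyTorch versions accept a float tensor of that shape in CrossEntropyLoss as class probabilities, or you can pass train_y.argmax(dim=1) instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the question's model/data (not shown in the post):
# 256 samples with 300 features and 20 classes, integer labels assumed.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_features, num_classes = 300, 20
x = torch.randn(256, num_features)
y = torch.randint(0, num_classes, (256,))
train_data = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Linear(num_features, num_classes).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer
loss_fn = nn.CrossEntropyLoss()  # created once, outside the loops

epoch_losses = []
for epoch_i in range(5):
    model.train()
    total_loss = 0.0
    for step, batch in enumerate(train_data):
        train_x, train_y = tuple(t.to(device) for t in batch)
        model.zero_grad()
        logits = model(train_x)
        loss = loss_fn(logits, train_y)
        loss.backward()
        optimizer.step()
        # .item() returns a Python float, so no autograd history is accumulated
        total_loss += loss.item()
    # printed only after every batch of this epoch has been processed
    epoch_losses.append(total_loss)
    print(epoch_i, total_loss)
```

With the print at the end of the epoch, each line now shows the summed loss over all batches of that epoch, which should generally shrink as training progresses.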