Torch.argmax returns a tensor containing all zeros?

Yes, sure.
I trained on my dataset with the network below, which must extract 21 features, but I'm not sure whether it works correctly or not. I guess the all-zero argmax values come from my model: when I return out2 in the model, I get all zeros, but when I change the network as shown below, the argmax has nonzero values! However, when I use that model, the accuracy is very low!
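To see which channel wins the argmax, I print per-channel statistics of the raw output (just a sketch; preds stands for the model output, assumed to have shape [N, 21, 64, 64, 64]):

# Sketch: per-channel statistics of the raw logits for one sample,
# to see whether channel 0 dominates everywhere.
sample = preds[0]                    # [21, 64, 64, 64]
flat = sample.reshape(21, -1)
print('per-channel mean:', flat.mean(dim=1).detach().cpu().numpy())
print('per-channel max :', flat.max(dim=1).values.detach().cpu().numpy())

Here is my training loop: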

import torch
from tqdm import tqdm

for epoch in tqdm(range(1, num_epochs + 1)):
    scheduler.step()  # note: recent PyTorch versions expect this after optimizer.step()
    lr = scheduler.get_lr()[0]

    model.train()  # training mode, so BatchNorm uses batch statistics
    test1_loss_total = 0.0
    num_steps = 0
    correct_test1 = 0
    total_test1 = 0

    ### Training
    for i, batch in enumerate(test1_loader):

        input_samples, gt_samples = batch[2], batch[3]
        var_input = input_samples
        var_gt1 = gt_samples
        optimizer.zero_grad()

        preds = model(var_input)                  # [8, 21, 64, 64, 64]
        var_gt = var_gt1.reshape(8, 64, 64, 64)   # class index per voxel (batch size 8 hard-coded)
        var_gt = var_gt.long()

        loss1 = criterion(preds, var_gt)
        loss1.backward()
        optimizer.step()
        test1_loss_total += loss1.item()
        num_steps += 1

        # predicted class = channel with the largest logit
        _, predicted_test = torch.max(preds, 1)
        total_test1 += var_gt.nelement()
        correct_test1 += predicted_test.eq(var_gt).sum().item()

        if epoch == num_epochs:
            print('***************************************************')
            predslist.append(preds)
            var_gtlist.append(var_gt1)
            print('done')

    test1_accuracy = 100 * correct_test1 / total_test1
    print('Epoch {}, test Loss: {:.3f}'.format(epoch, test1_loss_total / num_steps),
          "Testing Accuracy: %d %%" % (test1_accuracy))

print("finished training")

And I calculate the argmax over dim 0, which corresponds to the 21 classes:

import numpy as np

for ii in predslist:
    for j in range(0, batch_size):
        print('ii[j].shape:', ii[j].shape)
        # argmax over dim 0 picks, per voxel, the channel with the largest value
        om = torch.argmax(ii[j].squeeze(), dim=0).detach().cpu().numpy()
        print('np.unique(om):', np.unique(om))
        print('om.shape:', om.shape)
        print('************************************************************')

result:

************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7  9 10 12 13 14 15 16 18 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  7  9 10 12 13 14 15 16 18 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  2  5  7  9 10 11 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7 10 11 13 14 15 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7 10 11 12 13 14 15 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 1  2  5  7 10 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  2  4  5  7 10 11 12 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 1  2  5  7 10 11 13 14 16 18 20]
om.shape: (64, 64, 64)
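As a sanity check that dim=0 is the right axis, a toy tensor confirms that torch.argmax over dim 0 of a [C, D, H, W] tensor returns, for every voxel, the index of the largest channel:

import torch

# Toy check: channel 2 is the largest everywhere, so the
# dim-0 argmax should be 2 at every voxel.
t = torch.zeros(3, 2, 2, 2)   # [C, D, H, W]
t[2] = 1.0
om = torch.argmax(t, dim=0)
print(om.shape)     # torch.Size([2, 2, 2])
print(om.unique())  # tensor([2])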

And this is my model:

import torch.nn as nn
import torch.nn.functional as F1

class Myconv(nn.Module):
    def __init__(self):
        super(Myconv, self).__init__()
        # three conv -> batchnorm -> ReLU -> maxpool stages;
        # kernel 3, stride 1, padding 1 keep the 64^3 spatial size
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=10, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm1 = nn.BatchNorm3d(num_features=10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max1 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

        self.conv2 = nn.Conv3d(in_channels=10, out_channels=21, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm2 = nn.BatchNorm3d(num_features=21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max2 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

        self.conv3 = nn.Conv3d(in_channels=21, out_channels=21, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm3 = nn.BatchNorm3d(num_features=21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max3 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

    def forward(self, x):
        out1 = self.conv1(x)
        out1 = self.norm1(out1)
        out1 = F1.relu(out1)
        out1 = self.max1(out1)

        out2 = self.conv2(out1)
        out2 = self.norm2(out2)
        out2 = F1.relu(out2)
        out2 = self.max2(out2)

        out3 = self.conv3(out2)
        out3 = self.norm3(out3)
        out3 = F1.relu(out3)
        out3 = self.max3(out3)
        return out3
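To make sure the output shape matches what the loss and the dim-0 argmax expect, I can run a dummy forward pass (a sketch, assuming one-channel 64x64x64 inputs):

import torch

# Sketch: verify the model keeps the [N, 21, 64, 64, 64] shape
# that criterion and the argmax above both assume.
model = Myconv()
model.eval()
dummy = torch.randn(2, 1, 64, 64, 64)   # [N, C_in, D, H, W]
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([2, 21, 64, 64, 64])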

When I edit the parameters of the model and get higher accuracy (99%), I get an all-zero argmax!
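One thing I want to rule out is class imbalance: if most voxels in the ground truth are background (class 0), a model that predicts 0 everywhere still scores very high per-voxel accuracy, so 99% accuracy and an all-zero argmax would not contradict each other. Here is a quick check of the label distribution (a sketch, using the var_gtlist collected above):

# Sketch: ground-truth class frequencies over the collected batches.
counts = torch.zeros(21, dtype=torch.long)
for gt in var_gtlist:
    counts += torch.bincount(gt.reshape(-1).long(), minlength=21)
freqs = counts.float() / counts.sum()
for c, f in enumerate(freqs.tolist()):
    print('class {:2d}: {:.4f}'.format(c, f))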