torch.argmax returns a tensor containing all zeros?

I've trained my segmentation network using a U-Net model. Now I want to visualize the network output.

The shape of the output is [1, 2, 256, 256]: one channel for the background and the other for the foreground. Here is what I get for print(output.detach().squeeze()):

tensor([[[ 6.1109,  6.1109,  6.1126,  ...,  6.1580,  6.1317,  6.1317],
         [ 6.1109,  6.1109,  6.1126,  ...,  6.1580,  6.1317,  6.1317],
         [ 6.1125,  6.1125,  6.1139,  ...,  6.1281,  6.1075,  6.1075],
         ...,
         [ 8.3482,  8.3482,  8.3518,  ...,  8.7196,  8.7172,  8.7172],
         [ 8.3446,  8.3446,  8.3485,  ...,  8.7216,  8.7194,  8.7194],
         [ 8.3446,  8.3446,  8.3485,  ...,  8.7216,  8.7194,  8.7194]],

        [[-6.8949, -6.8949, -6.8821,  ..., -7.1064, -7.0877, -7.0877],
         [-6.8949, -6.8949, -6.8821,  ..., -7.1064, -7.0877, -7.0877],
         [-6.9068, -6.9068, -6.8939,  ..., -7.0928, -7.0780, -7.0780],
         ...,
         [-8.3679, -8.3679, -8.3719,  ..., -8.7457, -8.7435, -8.7435],
         [-8.3656, -8.3656, -8.3700,  ..., -8.7487, -8.7467, -8.7467],
         [-8.3656, -8.3656, -8.3700,  ..., -8.7487, -8.7467, -8.7467]]])

To get the predictions from this output I did:

predicted = torch.argmax(output.detach(), 1).squeeze(0)

But this gives a tensor with all zeros??

I then write this image along with the respective label using:

f_ = str(idx) + '_predicted.jpg'
io.imwrite(os.path.join(output_folder, f_), predicted.float().cpu().numpy())

f_ = str(idx) + '_target.jpg'
io.imwrite(os.path.join(output_folder, f_), labels.squeeze().float().cpu().numpy())

The shape of both predicted and label (after squeeze()) is [256, 256].

Label is being written fine as can be seen:

[image: 0_target]

But I am only getting black for predicted:

[image: 0_predicted]

Am I doing this right? How else am I supposed to visualize the network output?

Thanks
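A minimal sketch of turning the network output into a viewable mask, assuming the [1, 2, 256, 256] layout from the post. The class-index map from argmax over dim=1 is scaled to 0/255 as uint8 so foreground pixels show up white. The post does not say which library `io` is; imageio is assumed, and the write is left as a comment so the sketch runs without it. The toy logits and the helper name `logits_to_uint8_mask` are made up for illustration.

```python
import numpy as np
import torch

def logits_to_uint8_mask(output: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: convert [1, C, H, W] logits to an [H, W] uint8 image."""
    num_classes = output.shape[1]
    # Class index per pixel: argmax over the class dimension (dim=1) -> [1, H, W] -> [H, W]
    predicted = torch.argmax(output.detach(), dim=1).squeeze(0)
    # Spread class indices over 0..255 so they are visible (2 classes: 0 -> 0, 1 -> 255)
    scale = 255 // max(num_classes - 1, 1)
    return (predicted * scale).to(torch.uint8).cpu().numpy()

# Toy logits with the post's layout: background preferred everywhere except one square
output = torch.zeros(1, 2, 256, 256)
output[0, 0] = 1.0
output[0, 1, 100:150, 100:150] = 5.0
mask = logits_to_uint8_mask(output)
print(mask.shape, mask.dtype, np.unique(mask))  # shape (256, 256), dtype uint8, values {0, 255}
# io.imwrite("0_predicted.png", mask)  # assumes `io` is imageio
```

Saving masks as PNG rather than JPEG avoids lossy compression changing the label values.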

The strange thing is that np.amax(predicted.numpy()) gives me 0. How can there be no maximum value recorded in predicted?
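Note that torch.argmax returns the index of the largest entry along the given dimension, not the largest value itself, so an all-zero result only says that index 0 won at every pixel. A small sketch with made-up numbers in the same ranges as the posted output, showing that torch.max(..., dim=1) returns both the maximum values and their indices:

```python
import torch

# Made-up [1, 2, 2, 2] logits: channel 0 is larger everywhere, as in the posted output
x = torch.tensor([[[[6.1, 6.2], [8.3, 8.7]],
                   [[-6.9, -7.1], [-8.4, -8.7]]]])

indices = torch.argmax(x, dim=1)    # which channel is largest at each pixel -> all 0 here
values, idx = torch.max(x, dim=1)   # the largest values themselves, plus their channel indices
print(indices)
print(values)
print(torch.equal(idx, indices))
```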

I think the maximal values might be in channel0. In your example output tensor, channel0 seems to have only positive values, while channel1 seems to have only negative values, at least in the numbers you've posted.

Could this be the case, i.e. your model just predicts background for all pixels?
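One way to test that hypothesis is to count how many pixels are assigned to each class and to look at the smallest per-pixel margin between the two channels. A sketch using a made-up output filled with values in the ranges seen in the post (channel 0 around +6 to +9, channel 1 around -6 to -9):

```python
import torch

torch.manual_seed(0)
# Made-up [1, 2, 256, 256] output mimicking the posted value ranges
output = torch.empty(1, 2, 256, 256)
output[:, 0].uniform_(6.0, 9.0)    # background channel: large positive values
output[:, 1].uniform_(-9.0, -6.0)  # foreground channel: large negative values

predicted = torch.argmax(output, dim=1).squeeze(0)       # [256, 256]
# Pixels per class; minlength keeps classes that never win visible as 0
counts = torch.bincount(predicted.flatten(), minlength=2)
# Smallest margin by which background beats foreground; > 0 means no pixel goes to class 1
margin = (output[:, 0] - output[:, 1]).min()
print(counts)   # 65536 pixels -> class 0, 0 pixels -> class 1
print(margin)
```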

Hi @ptrblck

Even in that case, I should get some output from argmax instead of a zero tensor. The way I am visualizing the outputs here was suggested by you in my previous question: Understanding how to label/target tensors for 3D volumes

That approach works fine for 3D volumes. Here I am using 2D images, but the concept is the same, so it should work. During training I labeled my target tensor so that channel0 represents the background and channel1 represents the foreground.

So shouldn't I be looking for the maximum values in dim=1?

Yes, your approach is right, and I still think the zero tensor is a valid answer if your model just overfits to class0.

Have a look at this example:

x = torch.cat((torch.ones(1, 1, 5, 5), torch.zeros(1, 1, 5, 5)), 1)
print(x)
> tensor([[[[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]],

           [[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]]]])
print(torch.argmax(x, 1))
> tensor([[[0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0]]])

A zero tensor is the right answer in this case, since all values in channel0 are larger than those in channel1.
If we set a value in channel1 to something larger, you'll get the expected answer:

x[0, 1, 0, 0] = 2
print(torch.argmax(x, 1))
> tensor([[[1, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0]]])

Hi,
I want to get the argmax of my preds for each batch. When I print the argmax inside the loop, it returns values, but when I save the preds in a list, the argmax returns zeros for all elements in the list. I don't know why this happens. How can I get the values from the list?

My code is:

for epoch in tqdm(range(1, num_epochs + 1)):
    # start_time = time.time()
    num_examples =0
    scheduler.step()
    correct_pred = 0
    lr = scheduler.get_lr()[0]

    model.eval()
    train_loss_total = 0.0
    num_steps = 0
    test_loss_total = 0
    num = 0

    ### Training
    for i, batch in enumerate(test1_loader):

        input_samples, gt_samples = batch[2], batch[3]
        var_input =input_samples
        var_gt = gt_samples
        optimizer.zero_grad()

        preds = model(var_input)
        var_gt = var_gt.reshape(8,64,64,64)
        var_gt = var_gt.long()

        # print(preds.shape)
        loss1 = criterion(preds, var_gt)   
        # loss.backward(retain_graph=True)
        loss1.backward()
        optimizer.step()
        test1_loss_total += loss1.item()
        num_steps += 1

        _, predicted_test = torch.max(preds.data, 1)
        total_test1 += var_gt.nelement()
        correct_test1 += predicted_test.eq(var_gt.data).sum().item()
        print('************************************************************************')

        print(preds.shape)
        _, v = torch.max(preds, dim=0)
        print(v.unique())
        _, v1 = torch.max(var_gt, dim=0)
        print('   v1     ',v1.unique())


        print('----------------')

        if epoch == num_epochs:
           predslist.append(preds)
           var_gtlist.append(var_gt)
           print('done')

    test1_accuracy = 100 * correct_test1 / total_test1
    print('Epoch {}, test Loss: {:.3f}'.format(epoch, test1_loss_total / num_steps),
          "Testing Accuracy: %d %%" % (test1_accuracy))

    print("finished training")

And the result is:

Epoch 4, test Loss: 0.148 Testing Accuracy: 99 %
finished training
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([0, 3, 7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([2, 5, 7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
************************************************************************
torch.Size([8, 21, 64, 64, 64])
100%|██████████| 5/5 [21:29<00:00, 257.86s/it]tensor([0, 1, 2, 3, 4, 5, 6, 7])
   v1      tensor([7])
----------------
done
Epoch 5, test Loss: 0.185 Testing Accuracy: 99 %
finished training

My code for the list is:

for ii in predslist:
  for j in range(0, batch_size):
      om = torch.argmax(ii[j].squeeze(), dim=0)
      print (om.unique()) 

It returns only zeros.

To get the class predictions torch.argmax or torch.max should be called with the dim argument specifying the “class dimension”, i.e. the dimension which is set to the number of classes and contains the logits for each class.

In your training code you are using dim=0. Are you sure the class dimension is in dim0, not the batch dimension?

The last code snippet iterates over predslist and calls argmax again on dim=0, which will return zeros for a single sample in the batch.

For a vanilla classification use case your output would have the shape [batch_size, nb_classes] and you should therefore call torch.argmax(output, dim=1) as shown in my example.
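A short sketch of where the class dimension sits in the batched and per-sample cases, assuming a segmentation output laid out as [batch_size, nb_classes, D, H, W] like the [8, 21, 64, 64, 64] tensors in this thread (smaller spatial sizes keep it fast; the tensors are random placeholders). Reducing dim=0 of the batched tensor returns batch indices in 0..7, which matches the tensor([0, 1, ..., 7]) printed above.

```python
import torch

batch_size, nb_classes = 8, 21
# Vanilla classification: [batch_size, nb_classes] -> one class index per sample
logits = torch.randn(batch_size, nb_classes)
pred_cls = torch.argmax(logits, dim=1)               # shape [8]

# 3D segmentation: [batch_size, nb_classes, D, H, W] -> one class index per voxel
seg_logits = torch.randn(batch_size, nb_classes, 4, 4, 4)
pred_seg = torch.argmax(seg_logits, dim=1)           # shape [8, 4, 4, 4]

# Reducing dim=0 reduces the batch dimension: values are batch indices, not classes
wrong = torch.argmax(seg_logits, dim=0)              # shape [21, 4, 4, 4], values in 0..7

# For a single sample the batch dimension is gone, so the class dimension becomes dim=0
j = 3
pred_one = torch.argmax(seg_logits[j], dim=0)        # shape [4, 4, 4]
print(pred_cls.shape, pred_seg.shape, pred_one.shape)
print(torch.equal(pred_one, pred_seg[j]))            # True
```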

PS: the usage of .data is not recommended, as it might have unwanted side effects.
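A minimal sketch of computing voxel accuracy without .data, wrapping the metric in torch.no_grad() so autograd does not track it. The helper name `voxel_accuracy` is hypothetical; `preds` and `var_gt` mirror the names in the code above, and the random tensors are placeholders.

```python
import torch

def voxel_accuracy(preds: torch.Tensor, target: torch.Tensor) -> float:
    """preds: [N, C, D, H, W] logits, target: [N, D, H, W] class indices."""
    with torch.no_grad():
        predicted = preds.argmax(dim=1)              # class dimension is dim=1
        correct = (predicted == target).sum().item()
    return correct / target.numel()

preds = torch.randn(2, 21, 4, 4, 4, requires_grad=True)
var_gt = preds.detach().argmax(dim=1)                # fake targets that match exactly
print(voxel_accuracy(preds, var_gt))                 # 1.0
```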

Thanks for your reply.

Yes, that is the batch dimension; I made a mistake.
Each sample in predslist (ii[j]) has size torch.Size([21, 64, 64, 64]). When I take the argmax over dim=0, it returns only zeros.
I wrote this code to show the values of each element of predslist:

for ii in predslist:
  for j in range(0, batch_size):
      print(ii[j])
      print('ii[j].shape:',ii[j].shape)
      om = torch.argmax(ii[j].squeeze(), dim=0).detach().cpu().numpy()
      print ('np.unique(om):' ,np.unique(om)) 
      print('om.shape:',om.shape)
      print('************************************************************')

And here is some of the output:

ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [0]
om.shape: (64, 64, 64)
************************************************************
tensor([[[[ 42.3527,  42.4556,  42.4556,  ...,  30.8631,  29.5867,  28.1223],
          [ 42.4102,  42.9640,  42.9640,  ...,  30.8631,  29.5867,  28.1223],
          [ 42.4102,  42.9640,  42.9640,  ...,  30.8631,  29.5867,  28.1223],
          ...,
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783]],

         [[ 44.3158,  44.3947,  44.3947,  ...,  31.4897,  31.1392,  30.3788],
          [ 44.5054,  44.7882,  44.7882,  ...,  31.6356,  31.1392,  30.3788],
          [ 44.5054,  44.7882,  44.7882,  ...,  31.6356,  31.1392,  30.3788],
          ...,
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783]],

         [[ 45.5911,  45.5911,  45.5911,  ...,  32.1989,  32.1989,  31.6583],
          [ 45.7082,  45.7082,  45.7082,  ...,  32.5894,  32.4413,  31.6583],
          [ 45.7082,  45.7082,  45.7082,  ...,  32.6178,  32.4413,  31.6583],
          ...,
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783],
          [ 48.3376,  48.3376,  48.3376,  ...,  48.6783,  48.6783,  48.6783]],

         ...,

         [[138.6365, 140.9799, 142.6059,  ...,  58.8695,  58.8695,  58.8695],
          [144.4590, 144.4590, 144.4590,  ...,  58.8695,  58.8695,  58.8695],
          [144.4590, 144.4590, 144.4590,  ...,  58.8695,  58.8695,  58.8695],
          ...,
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309]],

         [[ 82.9900,  83.1118,  83.7028,  ...,  50.3198,  50.3198,  50.3198],
          [ 86.4008,  86.4008,  86.4008,  ...,  50.3198,  50.3198,  50.3198],
          [ 86.4008,  86.4008,  86.4008,  ...,  50.3198,  50.3198,  50.3198],
          ...,
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309]],

         [[ 46.8168,  46.8168,  46.8168,  ...,  47.2092,  47.2092,  47.2092],
          [ 46.8168,  46.8168,  46.8168,  ...,  47.2092,  47.2092,  47.2092],
          [ 46.8168,  46.8168,  46.8168,  ...,  47.2092,  47.2092,  47.2092],
          ...,
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309],
          [ 47.3822,  47.3822,  47.3822,  ...,  48.2309,  48.2309,  48.2309]]],


        [[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         ...,

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],


        [[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         ...,

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],


        ...,


        [[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         ...,

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],


        [[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         ...,

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],


        [[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         ...,

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]]],
       grad_fn=<SelectBackward>)

ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [0]
om.shape: (64, 64, 64)
************************************************************

As you can see, predslist has values, but the argmax over dim=0 is all zeros.
I want to decode preds, so I need non-zero values.
How can I get non-zero values from argmax?
The shape of the argmax output should be (64, 64, 64). Did I make a mistake somewhere in my code?
Also, if I print the argmax over dim=1, 2 or 3, it has non-zero values for the elements of predslist!
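Two observations about the dump above can be checked in isolation. In the printed sample, channel 0 is positive while the other printed channels are exactly 0, so channel 0 wins everywhere. And argmax over dim=1, 2 or 3 of a [21, 64, 64, 64] sample reduces a spatial axis, so its non-zero results are voxel positions (0..63), not class labels. The PyTorch argmax documentation also states that when several entries tie for the maximum, the index of the first one is returned, so voxels where every channel is 0 also map to class 0. A sketch with small made-up samples:

```python
import torch

# Made-up sample shaped [C, D, H, W] = [21, 4, 4, 4], like one ii[j] with a smaller volume
sample = torch.zeros(21, 4, 4, 4)
sample[0] = 42.0                      # channel 0 positive, all other channels exactly 0
print(torch.argmax(sample, dim=0).unique())   # tensor([0]): channel 0 wins everywhere

# All channels tied at 0: the first maximal index (0) is returned
tied = torch.zeros(21, 4, 4, 4)
print(torch.argmax(tied, dim=0).unique())     # tensor([0])

# Reducing a spatial dim gives positions along that axis, not classes
spatial = torch.argmax(torch.randn(21, 4, 4, 4), dim=1)
print(spatial.shape)                  # torch.Size([21, 4, 4]); values are depth indices 0..3
```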

I’m still unsure how you are calculating the predictions.
Could you post the corrected code again and explain what each dimension means, please?


Yes, sure.
I trained on my dataset with the following network, which should output 21 features (one per class), but I'm not sure whether it works correctly.
I guess the all-zero argmax is caused by my model: when I return out2 from the model, I get zeros, but when I change the network like this, the argmax has non-zero values!
However, with this model the accuracy is very low!

for epoch in tqdm(range(1, num_epochs + 1)):
    # start_time = time.time()
    num_examples =0
    scheduler.step()
    correct_pred = 0
    lr = scheduler.get_lr()[0]

    model.eval()
    train_loss_total = 0.0
    num_steps = 0
    test_loss_total = 0
    num = 0

    ### Training
    for i, batch in enumerate(test1_loader):

        input_samples, gt_samples = batch[2], batch[3]
        var_input =input_samples
        var_gt1 = gt_samples
        optimizer.zero_grad()

        preds = model(var_input)
        var_gt = var_gt1.reshape(8,64,64,64)
        var_gt = var_gt.long()

        # print(preds.shape)
        loss1 = criterion(preds, var_gt)   
        # loss.backward(retain_graph=True)
        loss1.backward()
        optimizer.step()
        test1_loss_total += loss1.item()
        num_steps += 1

        _, predicted_test = torch.max(preds, 1)
        total_test1 += var_gt.nelement()
        correct_test1 += predicted_test.eq(var_gt).sum().item()


        if epoch == num_epochs:
            print('***************************************************')
            predslist.append(preds)
            var_gtlist.append(var_gt1)
            print('done')
 
    test1_accuracy = 100 * correct_test1 / total_test1
    print('Epoch {}, test Loss: {:.3f}'.format(epoch, test1_loss_total / num_steps),
          "Testing Accuracy: %d %%" % (test1_accuracy))
    print("finished training")

And I calculate the argmax over dim=0, which is the dimension of the 21 classes:

for ii in predslist:
  for j in range(0, batch_size):
      # print(ii[j])
      print('ii[j].shape:',ii[j].shape)
      om = torch.argmax(ii[j].squeeze(), dim=0).detach().cpu().numpy()
      print ('np.unique(om):' ,np.unique(om)) 
      print('om.shape:',om.shape)
      print('************************************************************')

Result:

************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7  9 10 12 13 14 15 16 18 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  7  9 10 12 13 14 15 16 18 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  2  5  7  9 10 11 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7 10 11 13 14 15 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  5  7 10 11 12 13 14 15 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 1  2  5  7 10 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 0  1  2  4  5  7 10 11 12 13 14 16 18 19 20]
om.shape: (64, 64, 64)
************************************************************
ii[j].shape: torch.Size([21, 64, 64, 64])
np.unique(om): [ 1  2  5  7 10 11 13 14 16 18 20]
om.shape: (64, 64, 64)

And this is my model:

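import torch.nn as nn
import torch.nn.functional as F1  # assumed: F1 is torch.nn.functional in the original code

class Myconv(nn.Module):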
    def __init__(self):
        super(Myconv, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=10, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm1 = nn.BatchNorm3d(num_features=10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max1 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

        self.conv2 = nn.Conv3d(in_channels=10, out_channels=21, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm2 = nn.BatchNorm3d(num_features=21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max2 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

        self.conv3 = nn.Conv3d(in_channels=21, out_channels=21, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.norm3 = nn.BatchNorm3d(num_features=21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.max3 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1)

    def forward(self, x):
        out1 = self.conv1(x)
        out1 = self.norm1(out1)
        out1 = F1.relu(out1)
        out1 = self.max1(out1)

        out2 = self.conv2(out1)
        out2 = self.norm2(out2)
        out2 = F1.relu(out2)
        out2 = self.max2(out2)

        out3 = self.conv3(out2)
        out3 = self.norm3(out3)
        out3 = F1.relu(out3)
        out3 = self.max3(out3)
        return out3
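One thing worth checking in this model (an observation, not a confirmed diagnosis): forward() ends with F1.relu, so every negative class score is clamped to 0. Wherever all 21 channels end up at 0, the scores are tied, and torch.argmax returns the index of the first maximal value, i.e. class 0. The sketch below uses made-up random tensors (the tiny 2x2x2 volume is an assumption, chosen only to keep it small) to show that tie behaviour:

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for 21 classes over a tiny 2x2x2 volume (batch size 1).
# All scores are negative, e.g. a region where the network is unsure.
logits = -torch.rand(1, 21, 2, 2, 2) - 0.1

# Without ReLU, argmax picks the least negative class at each voxel.
pred_raw = torch.argmax(logits, dim=1)

# With a final ReLU, every score becomes 0, so all 21 classes tie.
pred_relu = torch.argmax(F.relu(logits), dim=1)

print(pred_relu.unique())  # tensor([0]): ties resolve to the first index
print(pred_raw.shape)      # torch.Size([1, 2, 2, 2])
```

nn.CrossEntropyLoss expects raw, unnormalized scores, so segmentation heads commonly return the last convolution's output without an activation. Whether that applies here depends on the loss used for training, which isn't shown in the thread.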

When I change the model's parameters and reach a higher accuracy (99%), argmax returns all zeros!
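One possible explanation, offered as an assumption since the thread doesn't show the label statistics: if the background class dominates the voxels, a model that predicts class 0 everywhere already reaches a very high voxel accuracy. The arithmetic below uses an invented 100x100 mask with 1% foreground to illustrate this:

```python
import torch

# Hypothetical 100x100 label map: 1% of pixels are foreground (class 1).
target = torch.zeros(100, 100, dtype=torch.long)
target[:10, :10] = 1  # 100 of 10,000 pixels

# A degenerate prediction: background everywhere.
pred = torch.zeros_like(target)

accuracy = (pred == target).float().mean().item()
print(f"pixel accuracy: {accuracy:.2%}")  # 99.00%

# Per-class counts expose what the overall accuracy hides:
# the prediction contains no foreground pixels at all.
pred_counts = torch.bincount(pred.flatten(), minlength=2)
target_counts = torch.bincount(target.flatten(), minlength=2)
print(pred_counts, target_counts)
```

Per-class metrics such as IoU or Dice, or simply the per-class counts above, are usually more informative than overall accuracy for segmentation with a dominant background class.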

@ptrblck, I have a few tensors of size torch.Size([4, 11, 256, 256]) with unique values [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. When I use torch.argmax(mask, dim=1), it just returns zero with torch.Size([1]).
What is the problem?

mask = tensor([[[[5., 5., 5., ..., 4., 4., 4.], [5., 5., 5., ..., 4., 4., 4.], [5., 5., 5., ..., 4., 4., 4.], ..., [0., 0., 0., ..., 4., 4., 4.], [0., 0., 0., ..., 4., 4., 4.], [0., 0., 0., ..., 4., 4., 4.]], dtype=torch.float64)

mask = torch.unique(torch.argmax(mask, dim=1))
# print(mask.shape, mask)
# torch.Size([1]) tensor([0])

This shouldn’t happen as only one dimension should be reduced:

import torch

x = torch.randint(0, 11, [4, 11, 256, 256])
out = torch.argmax(x, dim=1)
print(out.shape)
# torch.Size([4, 256, 256])

In your code you are additionally using torch.unique, so did you check whether the output of torch.argmax itself is all zeros (which would be the case if channel 0 contains the max values)?

Yes, the output of torch.argmax is all zeros. I used torch.unique to make sure all classes are present (i.e. values between 0 and 10), but noticed that it only returns zeros. With a random input it works totally fine.
mask = tensor([[[[3., 3., 2.,  ..., 1., 1., 1.],
          [3., 3., 2.,  ..., 1., 1., 1.],
          [3., 2., 2.,  ..., 1., 1., 1.],
          ...,
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 2., 2., 2.]],

         [[3., 3., 2.,  ..., 1., 1., 1.],
          [3., 3., 2.,  ..., 1., 1., 1.],
          [3., 2., 2.,  ..., 1., 1., 1.],
          ...,
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 2., 2., 2.]],

         [[3., 3., 2.,  ..., 1., 1., 1.],
          [3., 3., 2.,  ..., 1., 1., 1.],
          [3., 2., 2.,  ..., 1., 1., 1.],
          ...,
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 3., 3., 3.],
          [2., 2., 2.,  ..., 2., 2., 2.]],
# torch.Size([1, 11, 256, 256])

mask = torch.argmax(mask, dim=1)
# tensor([[[0, 0, 0,  ..., 0, 0, 0],
#          [0, 0, 0,  ..., 0, 0, 0],
#          [0, 0, 0,  ..., 0, 0, 0],
#          ...,
#          [0, 0, 0,  ..., 0, 0, 0],
#          [0, 0, 0,  ..., 0, 0, 0],
#          [0, 0, 0,  ..., 0, 0, 0]]])
# torch.Size([1, 256, 256])

Based on these results, could you check whether mask contains the largest values in channel 0?

import torch

x = torch.randint(0, 11, [1, 11, 256, 256])
# make sure channel 0 contains the largest values
x[:, 0] += 100000
out = torch.argmax(x, dim=1)

print((out == 0).all())
# tensor(True)

I noticed that all channels (i.e. 0-10) contain the same repeated values:

mask[0, 0] = tensor([[3., 3., 2.,  ..., 1., 1., 1.],
        [3., 3., 2.,  ..., 1., 1., 1.],
        [3., 2., 2.,  ..., 1., 1., 1.],
        ...,
        [2., 2., 2.,  ..., 3., 3., 3.],
        [2., 2., 2.,  ..., 3., 3., 3.],
        [2., 2., 2.,  ..., 2., 2., 2.]])

mask[0, 1] = tensor([[3., 3., 2.,  ..., 1., 1., 1.],
        [3., 3., 2.,  ..., 1., 1., 1.],
        [3., 2., 2.,  ..., 1., 1., 1.],
        ...,
        [2., 2., 2.,  ..., 3., 3., 3.],
        [2., 2., 2.,  ..., 3., 3., 3.],
        [2., 2., 2.,  ..., 2., 2., 2.]])
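Repeated channels would explain the zeros: if all 11 channels hold the same values, every pixel is a tie and torch.argmax falls back to index 0. A minimal sketch, assuming mask is a label map that was repeated across the channel dimension (the thread doesn't show how mask was built, and the random labels here are made up): check whether all channels are identical and, if so, read the class indices from a single channel instead of calling argmax. If a one-hot tensor is actually needed, F.one_hot can build it from the label map.

```python
import torch
import torch.nn.functional as F

num_classes = 11
# Hypothetical label map with class indices 0..10, shape [1, 256, 256].
labels = torch.randint(0, num_classes, (1, 256, 256))

# Reproduce the problem: the same label map repeated in every channel.
mask = labels.unsqueeze(1).repeat(1, num_classes, 1, 1).double()  # [1, 11, 256, 256]

# Every channel equals channel 0, so argmax ties everywhere and returns 0.
channels_identical = (mask == mask[:, :1]).all()
print(channels_identical)                  # tensor(True)
print(torch.argmax(mask, dim=1).unique())  # tensor([0])

# Fix 1: the class indices are already stored in any single channel.
class_map = mask[:, 0].long()              # [1, 256, 256]

# Fix 2: if a one-hot encoding is needed, build it from the label map.
one_hot = F.one_hot(class_map, num_classes)  # [1, 256, 256, 11]
one_hot = one_hot.permute(0, 3, 1, 2)        # [1, 11, 256, 256]

# With a proper one-hot tensor, argmax recovers the original labels.
print(torch.equal(torch.argmax(one_hot, dim=1), labels))  # True
```

The check on channels_identical is cheap, so it can be run on the real mask before deciding whether argmax is the right operation at all.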

I understand the issue now.
Thank you for your help!