Understanding how to label/target tensors for 3D volumes

Coming back to this old question of mine: I’m trying to replicate this simple example for 1-channel segmentation using BCE, but it does not seem to be working.

```python
import torch
import torch.nn as nn

torch.manual_seed(2809)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# One sample, one channel, a 50 x 300 x 300 volume
data = torch.randn(1, 1, 50, 300, 300)
label = torch.zeros(1, 1, 50, 300, 300)
label[0, 20:32, 80:150, 80:150] = 1

data = data.to(device)
label = label.to(device)

# Small fully convolutional net producing one logit per voxel
model = nn.Sequential(
    nn.Conv3d(1, 32, 3, 1, 1),
    nn.ReLU(),
    nn.Conv3d(32, 16, 3, 1, 1),
    nn.ReLU(),
    nn.Conv3d(16, 1, 3, 1, 1)
)

model = model.to(device)

loss_criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    optimizer.zero_grad()

    output = model(data)
    loss = loss_criterion(output, label)
    loss.backward()
    optimizer.step()

    print('Epoch {}, loss {}'.format(epoch, loss.item()))

print('training finished')
```
```
Epoch 0, loss 0.6588255763053894
Epoch 1, loss 0.0022372708190232515
Epoch 2, loss 1.3424497410596814e-05
Epoch 3, loss 4.030329137094668e-07
Epoch 4, loss 2.9725489625320733e-08
Epoch 5, loss 3.917283386556392e-09
Epoch 6, loss 8.141641250070109e-10
Epoch 7, loss 2.1598697830249591e-10
Epoch 8, loss 6.574558952809895e-11
Epoch 9, loss 2.1669017274961178e-11
Epoch 10, loss 7.65580914635633e-12
Epoch 11, loss 2.8875027369146267e-12
Epoch 12, loss 1.1126182440646115e-12
Epoch 13, loss 4.768368859486838e-13
Epoch 14, loss 2.1192757564039016e-13
Epoch 15, loss 1.0596380137272224e-13
Epoch 16, loss 5.298190068636112e-14
Epoch 17, loss 2.6490952037246454e-14
Epoch 18, loss 0.0
Epoch 19, loss 0.0
Epoch 20, loss 0.0
Epoch 21, loss 0.0
Epoch 22, loss 0.0
Epoch 23, loss 0.0
Epoch 24, loss 0.0
```

The loss goes to 0, but I still can’t seem to extract a meaningful prediction for this simple example.
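For reference, a loss of exactly 0 is what `nn.BCEWithLogitsLoss` reports when every target voxel is 0 and every logit is strongly negative: for a target of 0 the per-voxel loss is log(1 + exp(x)), which underflows to 0 in float32 once x is negative enough. A small sketch with constant, made-up logits (the values are arbitrary and only illustrate the trend):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
target = torch.zeros(1, 1, 4, 8, 8)  # an all-background target

# Push every logit further negative and watch the mean loss shrink towards 0.
for value in (0.0, -5.0, -20.0, -200.0):
    logits = torch.full_like(target, value)
    print(value, criterion(logits, target).item())
```

At a logit of 0 the loss is log 2 (about 0.693), close to the epoch-0 value in the log above.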

First off, argmax() does not seem to work for me, even with a single channel. As in the example above, I get the prediction with

```python
torch.argmax(output.detach(), 1).squeeze()
```

which should give me the index of the maximum value along the channel dimension, but instead I get a tensor containing only zeros; even max(torch.argmax(...)) returns 0. I asked about this before, when I was also getting 0 from argmax, but there the maximum values were located in another channel. Why do I still get 0 here, with only a single channel?
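On the argmax part: `torch.argmax(x, dim=1)` returns an index into dimension 1, not a value. With a single output channel the only valid index is 0, so the result is all zeros no matter what the logits are. A quick sketch with random tensors of an arbitrary small shape:

```python
import torch

torch.manual_seed(0)
single = torch.randn(1, 1, 4, 8, 8)  # (N, C=1, D, H, W) logits
multi = torch.randn(1, 3, 4, 8, 8)   # (N, C=3, D, H, W) logits

# With C == 1 there is only one candidate index along dim 1, so every entry is 0.
idx_single = torch.argmax(single, dim=1)
print(idx_single.unique())

# With C > 1 argmax picks the highest-scoring class per voxel, so 0, 1 and 2 can appear.
idx_multi = torch.argmax(multi, dim=1)
print(idx_multi.unique())
```

So argmax over channels is a meaningful prediction step only when there is one output channel per class (as with `nn.CrossEntropyLoss`).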

As a workaround I tried thresholding the output with `pred = (output > 0) * 1`, but I still only get 0’s.
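Thresholding the logits at 0 is the usual way to read out a single-channel `BCEWithLogitsLoss` model: since sigmoid(0) = 0.5, `output > 0` and `torch.sigmoid(output) > 0.5` select the same voxels (up to float rounding for logits extremely close to 0). If that still yields only zeros, every logit is negative, i.e. the model is predicting background everywhere. A sketch using random stand-in logits rather than the real model output:

```python
import torch

torch.manual_seed(0)
# Random stand-in for a single-channel model output of shape (N, 1, D, H, W).
logits = torch.randn(1, 1, 4, 8, 8)

# Threshold the raw logits at 0 ...
mask_from_logits = (logits > 0).long()
# ... or threshold the sigmoid probabilities at 0.5 (sigmoid(0) == 0.5).
mask_from_probs = (torch.sigmoid(logits) > 0.5).long()

# Fraction of voxels where both masks agree; differences can only come from
# logits within float rounding of 0.
agreement = (mask_from_logits == mask_from_probs).float().mean().item()
print(agreement)

# Drop the channel dim to get an (N, D, H, W) mask of 0s and 1s.
mask = mask_from_logits.squeeze(1)
print(mask.shape, mask.sum().item())
```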

One explanation could be that the network converged to predicting background everywhere, but my setup is very similar to the quoted code (which works great); the only differences are that I have a single output channel and use BCEWithLogitsLoss.
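The background hypothesis can be checked directly on the target. `label` is 5-D, `(N, C, D, H, W) = (1, 1, 50, 300, 300)`, but `label[0, 20:32, 80:150, 80:150] = 1` uses only four indices: `20:32` lands on the channel dimension (size 1) and the first `80:150` on the depth dimension (size 50). Both slices are empty, so the assignment writes nothing and PyTorch raises no error. A minimal check, with a five-index version (`fixed`, a name introduced here for illustration) for comparison:

```python
import torch

label = torch.zeros(1, 1, 50, 300, 300)

# Four indices on a 5-D tensor: 20:32 slices the size-1 channel dim and
# 80:150 slices the size-50 depth dim, so the selection is empty.
label[0, 20:32, 80:150, 80:150] = 1
print(label[0, 20:32, 80:150, 80:150].shape)  # torch.Size([0, 0, 70, 300])
print(label.sum().item())  # 0.0, nothing was written

# Five indices: batch 0, channel 0, then the depth/height/width ranges.
fixed = torch.zeros(1, 1, 50, 300, 300)
fixed[0, 0, 20:32, 80:150, 80:150] = 1
print(fixed.sum().item())  # 58800.0 = 12 * 70 * 70 foreground voxels
```

Equivalently, `label[:, :, 20:32, 80:150, 80:150] = 1` fills the cube for every sample and channel.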

An explanation of why I am having trouble here would be very helpful.

Many thanks