Understanding how to label/target tensors for 3D volumes

I understand how labeling works for semantic segmentation of 2D images. I was able to create label (target) tensors using the colour coding provided with the dataset. The colour codes were:

    ("Animal", np.array([64, 128, 64], dtype=np.uint8)),
    ("Archway", np.array([192, 0, 128], dtype=np.uint8)),
    ("Bicyclist", np.array([0, 128, 192], dtype=np.uint8)),
    ("Bridge", np.array([0, 128, 64], dtype=np.uint8)),
    ("Building", np.array([128, 0, 0], dtype=np.uint8)),
    ("Car", np.array([64, 0, 128], dtype=np.uint8)),
    ("CartLuggagePram", np.array([64, 0, 192], dtype=np.uint8)),
    ("Child", np.array([192, 128, 64], dtype=np.uint8)),

To create a label for my network, I would simply locate these RGB colours in the images, match each “object” to its colour code and then assign it a class index in a grayscale image. So each label would be a single-channel tensor, and the output of my network would have as many channels as there are classes in the dataset (32 for CamVid).
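For the 2D case, the colour-to-index step could look roughly like the following sketch. It assumes the mask is an [H, W, 3] uint8 NumPy array, uses only the first three colour codes for brevity, and marks pixels that match no colour with -100 (the default ignore_index of nn.CrossEntropyLoss); the function name rgb_to_class_indices is hypothetical.

```python
import numpy as np
import torch

# Subset of the CamVid colour codes listed above: (class name, RGB value).
COLOUR_CODES = [
    ("Animal", np.array([64, 128, 64], dtype=np.uint8)),
    ("Archway", np.array([192, 0, 128], dtype=np.uint8)),
    ("Bicyclist", np.array([0, 128, 192], dtype=np.uint8)),
]


def rgb_to_class_indices(rgb_mask: np.ndarray, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: map an [H, W, 3] colour mask to [H, W] class indices.

    Pixels whose colour matches no entry get ignore_index, so they are
    skipped by nn.CrossEntropyLoss instead of being counted as class 0.
    """
    label = np.full(rgb_mask.shape[:2], ignore_index, dtype=np.int64)
    for class_idx, (_, colour) in enumerate(COLOUR_CODES):
        # A pixel belongs to this class if all three channels match.
        matches = np.all(rgb_mask == colour, axis=-1)
        label[matches] = class_idx
    return torch.from_numpy(label)  # dtype torch.long


mask = np.zeros((2, 2, 3), dtype=np.uint8)
mask[0, 0] = [192, 0, 128]  # Archway -> index 1
mask[1, 1] = [0, 128, 192]  # Bicyclist -> index 2
out = rgb_to_class_indices(mask)
print(out)
```

The result is a single-channel class-index map per image, which is exactly the target format used for the 2-channel 3D case later in the thread, just with one more spatial dimension.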

Now I am trying to do the same in 3D and am struggling to understand what to do. The problem is a binary one, as I have to detect a single item in the input, so I should have two output channels: background and foreground. My input is a single-channel 3D volume with the shape [BS, Channel, Z, X, Y].

For this input, I have a tensor that indicates the locations of my objects of interest. A 3D Gaussian is built around each of these points.
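A target like this is presumably built along the following lines. This is a sketch, not the poster's code: it assumes the points are given as (z, x, y) voxel indices, uses an isotropic Gaussian with a hypothetical sigma, and merges overlapping Gaussians with a voxel-wise maximum. Note that the result is a float tensor with values in [0, 1], which matters for the choice of loss discussed later in the thread.

```python
import torch


def gaussian_heatmap_3d(points, shape, sigma=2.0):
    """Hypothetical helper: render an isotropic 3D Gaussian around each point.

    points: iterable of (z, x, y) voxel coordinates (assumed format)
    shape:  (Z, X, Y) size of the output volume
    Overlapping Gaussians are combined with a voxel-wise max, so peaks are 1.
    """
    zz, xx, yy = torch.meshgrid(
        torch.arange(shape[0], dtype=torch.float32),
        torch.arange(shape[1], dtype=torch.float32),
        torch.arange(shape[2], dtype=torch.float32),
        indexing="ij",
    )
    heatmap = torch.zeros(shape)
    for z, x, y in points:
        dist_sq = (zz - z) ** 2 + (xx - x) ** 2 + (yy - y) ** 2
        heatmap = torch.maximum(heatmap, torch.exp(-dist_sq / (2 * sigma ** 2)))
    return heatmap


# Two made-up points in a [16, 64, 64] volume, matching the size used below.
target = gaussian_heatmap_3d([(8, 20, 30), (4, 50, 10)], shape=(16, 64, 64))
print(target.shape, target.dtype, target.max())
```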

I hope I have explained my situation clearly.

My confusion is how to create a label tensor in 3D for segmenting my points of interest as I have done for the CamVid dataset.

Below is a 2D representation of a target with 3D Gaussians drawn on the points of interest; the tensor has the size [16, 64, 64].

[Image: target]

Here is the target file as a pickle dump (just in case): Box

As in the CamVid case, where I had one channel per class, I know that here I would have 2 channels, one for the background and one for the foreground, but I don't understand how to label these 3D tensors.

Basically, you can handle it like the 2D case with an additional “spatial” dimension.
I.e. the target should contain the class indices (as torch.long) without a channel dimension, in the shape [batch_size, d, h, w].
I’ve created a small dummy example using a simple model to segment a cube in the volume:

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(2809)

c, d, h, w = 3, 12, 12, 12
data = torch.randn(1, c, d, h, w)
# target: class indices without a channel dimension, shape [batch_size, d, h, w]
target = torch.zeros(1, d, h, w, dtype=torch.long)
target[0, 4:8, 4:8, 4:8] = 1  # foreground cube

# two output channels: background (0) and foreground (1)
model = nn.Sequential(
    nn.Conv3d(c, 6, 3, 1, 1),
    nn.ReLU(),
    nn.Conv3d(6, 2, 3, 1, 1)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(800):
    optimizer.zero_grad()

    output = model(data)  # [1, 2, d, h, w]
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    print('Epoch {}, loss {}'.format(epoch, loss.item()))


# predicted class per voxel: argmax over the channel dimension -> [d, h, w]
pred = torch.argmax(output.detach(), 1).squeeze(0)
pred.unique()
# plot the 12 depth slices of the prediction and of the target
rows, cols = 4, 3
fig, axarr = plt.subplots(rows, cols)
for idx in range(pred.size(0)):
    axarr[idx//cols, idx%cols].imshow(pred[idx])

fig, axarr = plt.subplots(rows, cols)
for idx in range(target.size(1)):
    axarr[idx//cols, idx%cols].imshow(target[0, idx])

plt.show()

Let me know if that helps in creating your model.

Many thanks @ptrblck, this cleared up a lot of my doubts about 3D labeling. I see that the output has 2 channels, whereas the target has no channel dimension at all. I was under the impression that both needed to have the same dimensions, but I see that is not the case.

So I don't have to do any special processing of my label data, as it is already in the required format (a zero tensor with 3D Gaussians on the points of interest).

Can you please clarify the purpose of pred.unique() in the above code? I get the same result with or without this line.

Thanks again

Currently your target has float values, which won’t work with nn.CrossEntropyLoss.
You should transform these values to class indices, e.g. by using a threshold.
This makes sure that the center of each Gaussian is marked as class 1, while all other voxels belong to class 0 (or vice versa).
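A minimal sketch of that conversion, assuming a float heatmap shaped [batch, Z, X, Y] and an illustrative threshold of 0.5 (neither the blob values nor the threshold come from the thread): the comparison yields a boolean mask, and .long() turns it into the class-index target nn.CrossEntropyLoss expects, with no channel dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in float heatmap target [batch, Z, X, Y] with a crude Gaussian-like blob.
heatmap = torch.zeros(1, 16, 64, 64)
heatmap[0, 6:10, 28:36, 28:36] = 0.4
heatmap[0, 7:9, 30:34, 30:34] = 1.0

# Threshold (0.5 is an assumption; tune it for your Gaussians) and convert
# to class indices: 1 = foreground near the centre, 0 = background.
target = (heatmap > 0.5).long()  # shape [1, 16, 64, 64], dtype torch.long

# Two output channels (background, foreground), as in the dummy example above.
output = torch.randn(1, 2, 16, 64, 64, requires_grad=True)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
print(target.dtype, target.unique(), loss.item())
```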

Haha, just skip the pred.unique() part. I only used it for debugging, to see when the model started to predict ones, and forgot to remove it. :wink:

Hi @ptrblck, I'm revisiting this old post of mine to clarify one more thing about the label tensor. I see that nn.CrossEntropyLoss() does not need the output and target tensors to be of the same size, so I can compare the 2-channel output [1, 2, 16, 64, 64] with the binary target without a channel dimension [1, 16, 64, 64]. As with the sample code, this works fine with my data as well.

But from what I understand, a Dice loss and binary cross-entropy (nn.BCELoss / nn.BCEWithLogitsLoss) need both tensors to be of the same size. If I want to use such a loss criterion, how can I give my label tensor 2 channels? Can I make a label tensor with 2 channels, one for the background and one for the foreground? Or would it be better to change the last layer of the network to output only a single channel, i.e. a size of [1, 1, 16, 64, 64]?

Thank you

For a mutually exclusive binary classification, i.e. only one class can be set at any time, I would reduce the output to a single channel.
If you are dealing with a multi-label classification, i.e. both classes can be set for a sample, you should use multiple channels, but this doesn’t seem to make much sense using a background and foreground class. :wink:
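To make both options from this exchange concrete, here is a sketch with assumed shapes (the [1, 16, 64, 64] volume from earlier in the thread, random data). The single-channel variant only needs an extra channel dimension and a float dtype on the target; the one-hot variant via F.one_hot is shown for completeness, in case a loss really requires per-class channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Class-index target as used with nn.CrossEntropyLoss: [N, D, H, W], torch.long.
target_idx = torch.randint(0, 2, (1, 16, 64, 64))

# Option 1 (suggested above): one output channel + nn.BCEWithLogitsLoss.
# The target gets a channel dimension and a floating point dtype.
logits_1ch = torch.randn(1, 1, 16, 64, 64)
target_1ch = target_idx.unsqueeze(1).float()  # [1, 1, 16, 64, 64]
bce = nn.BCEWithLogitsLoss()(logits_1ch, target_1ch)

# Option 2: a two-channel (one-hot) target. F.one_hot appends the class
# dimension last, so it is moved to dim 1 to match [N, C, D, H, W].
target_2ch = F.one_hot(target_idx, num_classes=2).permute(0, 4, 1, 2, 3).float()

print(target_1ch.shape, target_2ch.shape, bce.item())
```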

I learn something new every day. Many thanks as always :slight_smile:

As suggested, I changed my output_channels to 1, but now I am getting the following error:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.

Searching for the error suggests that my labels are not zero-indexed. I guess that is why it worked with output_channels=2 and fails now that the number of output channels is the same as that of the target.

How would I zero-index my target tensor to satisfy this condition?

It looks like you are still using nn.CrossEntropyLoss. I understood that you would like to use nn.BCEWithLogitsLoss for the binary classification task.
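The assertion fires because, with a single output channel, nn.CrossEntropyLoss only accepts the class index 0, while the target still contains 1s; assuming the target is already the thresholded 0/1 tensor, it is already zero-indexed and re-indexing would not help. A small check like this sketch can surface the mismatch before the loss call. The helper name is hypothetical, and it does not account for an ignore_index.

```python
import torch


def check_cross_entropy_inputs(output: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical sanity check before calling nn.CrossEntropyLoss.

    output: [N, C, D, H, W] logits; target: [N, D, H, W] class indices.
    Does not handle ignore_index values such as -100.
    """
    n_classes = output.size(1)
    if target.dtype != torch.long:
        raise TypeError(f"target must be torch.long, got {target.dtype}")
    expected = tuple(output.shape[:1]) + tuple(output.shape[2:])
    if tuple(target.shape) != expected:
        raise ValueError(f"target shape {tuple(target.shape)} should be {expected}")
    if target.min() < 0 or target.max() >= n_classes:
        raise ValueError(
            f"target values must lie in [0, {n_classes - 1}], "
            f"got [{target.min().item()}, {target.max().item()}]"
        )


target = torch.zeros(1, 16, 64, 64, dtype=torch.long)
target[0, 6:10, 28:36, 28:36] = 1

check_cross_entropy_inputs(torch.randn(1, 2, 16, 64, 64), target)  # passes

try:
    # One output channel: the only valid class index is 0.
    check_cross_entropy_inputs(torch.randn(1, 1, 16, 64, 64), target)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```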

Oh yes, sorry, I did not clarify. I am running three instances to compare the results of CrossEntropy, BinaryCrossEntropy and DiceLoss.

OK, sorry for the misunderstanding.
In this case:

  • for nn.CrossEntropyLoss you should use nb_classes channels in your output. The target does not contain the channel dimension (but all others) and stores class indices of type torch.long.
  • nn.BCEWithLogitsLoss should get output and target tensors with the same shape (and the same floating point dtype).
  • the Dice loss is usually computed on probabilities, i.e. you would apply a sigmoid to your output; thresholding them gives hard predictions (e.g. for a Dice score during evaluation), but a threshold is not differentiable. However, this depends on your implementation of the Dice loss. The shapes should be the same for your output and target.
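The three cases can be put side by side in a short sketch. The shapes follow the list above; the soft Dice formulation (sigmoid probabilities, sums over all non-batch dimensions, a small eps for stability) is one common variant chosen here as an assumption, since torch.nn does not ship a Dice loss module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Assumed soft Dice loss on sigmoid probabilities (not a PyTorch built-in).

    logits and target must have the same shape, e.g. [N, 1, D, H, W].
    """
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, probs.dim()))  # sum over everything except the batch dim
    intersection = (probs * target).sum(dims)
    union = probs.sum(dims) + target.sum(dims)
    dice = (2 * intersection + eps) / (union + eps)
    return 1 - dice.mean()


# Binary foreground mask shared by all three losses: [N, D, H, W].
mask = torch.zeros(2, 16, 64, 64, dtype=torch.long)
mask[:, 6:10, 28:36, 28:36] = 1

# nn.CrossEntropyLoss: nb_classes output channels, long target without channel dim.
ce = nn.CrossEntropyLoss()(torch.randn(2, 2, 16, 64, 64), mask)

# nn.BCEWithLogitsLoss: output and target with identical shape and float dtype.
logits = torch.randn(2, 1, 16, 64, 64)
float_mask = mask.unsqueeze(1).float()  # [2, 1, 16, 64, 64]
bce = nn.BCEWithLogitsLoss()(logits, float_mask)

# Soft Dice: same shapes as for BCE.
dice = soft_dice_loss(logits, float_mask)

print(ce.item(), bce.item(), dice.item())
```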

Coming back to this old question of mine, I’m trying to replicate this simple example for a 1-channel segmentation using BCE, but it does not seem to work:

import torch
import torch.nn as nn

torch.manual_seed(2809)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

data = torch.randn(1, 1, 50, 300, 300)
label = torch.zeros(1, 1, 50, 300, 300)
label[0, 20:32, 80:150, 80:150] = 1

data = data.to(device)
label = label.to(device)

model = nn.Sequential(
        nn.Conv3d(1, 32, 3, 1, 1),
        nn.ReLU(),
        nn.Conv3d(32, 16, 3, 1, 1),
        nn.ReLU(),
        nn.Conv3d(16, 1, 3, 1, 1)
    )

model = model.to(device)

loss_criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    optimizer.zero_grad()

    output = model(data)
    loss = loss_criterion(output, label)
    loss.backward()
    optimizer.step()

    print('Epoch {}, loss {}'.format(epoch, loss.item()))

print('training finished')
Epoch 0, loss 0.6588255763053894
Epoch 1, loss 0.0022372708190232515
Epoch 2, loss 1.3424497410596814e-05
Epoch 3, loss 4.030329137094668e-07
Epoch 4, loss 2.9725489625320733e-08
Epoch 5, loss 3.917283386556392e-09
Epoch 6, loss 8.141641250070109e-10
Epoch 7, loss 2.1598697830249591e-10
Epoch 8, loss 6.574558952809895e-11
Epoch 9, loss 2.1669017274961178e-11
Epoch 10, loss 7.65580914635633e-12
Epoch 11, loss 2.8875027369146267e-12
Epoch 12, loss 1.1126182440646115e-12
Epoch 13, loss 4.768368859486838e-13
Epoch 14, loss 2.1192757564039016e-13
Epoch 15, loss 1.0596380137272224e-13
Epoch 16, loss 5.298190068636112e-14
Epoch 17, loss 2.6490952037246454e-14
Epoch 18, loss 0.0
Epoch 19, loss 0.0
Epoch 20, loss 0.0
Epoch 21, loss 0.0
Epoch 22, loss 0.0
Epoch 23, loss 0.0
Epoch 24, loss 0.0

The loss goes to 0, but I still can't seem to extract a prediction for this simple example.

First off, argmax() does not seem to work for me, even for a single channel. As in the example above, to get the prediction I use

torch.argmax(output.detach(), 1).squeeze()

which should give me the locations of the maximum values in the channel dimension, but instead I get a tensor containing only zeros. Even max(torch.argmax()) produces 0. I asked about this before as well, when I was also getting 0 from argmax, but there the maximum values were located in another channel. Why am I still getting 0 with a single channel here?

As a workaround, I tried to threshold the output with pred = (output > 0) * 1, but I still only get zeros.

The answer could be that the network converged to predicting background only, but my case is similar to the quoted code (which works great); the only difference is that I have a single output channel and am using BCEWithLogits.

It would be very helpful to understand why I am having trouble here.

Many thanks

argmax won’t work using a single channel, since this single channel will always contain the maximal value, so you’ll get a prediction containing all zeros.

However, the threshold method should work.
Could you apply a threshold of 0.5 and try it again?
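Both points can be checked on a small tensor. This sketch uses a deterministic stand-in for single-channel logits (assumed values, not the thread's model output) and also shows that for logits the threshold corresponding to a probability of 0.5 is 0:

```python
import torch

# Deterministic stand-in for single-channel logits [N, 1, D, H, W]
# (assumed values; none of them is exactly 0).
output = torch.linspace(-3.0, 3.0, steps=500).reshape(1, 1, 5, 10, 10)

# argmax over a channel dimension of size 1 can only return index 0.
argmax_pred = torch.argmax(output, dim=1)
print(argmax_pred.unique())

# sigmoid(0) == 0.5 and sigmoid is increasing, so for logits
# "probability > 0.5" is the same test as "logit > 0".
pred_from_logits = output > 0
pred_from_probs = torch.sigmoid(output) > 0.5
print(torch.equal(pred_from_logits, pred_from_probs), pred_from_logits.sum().item())
```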

I am still getting all zeros. This may be because of the maximum value of the output:

print(output.max())
tensor(-23.1737, device='cuda:0', grad_fn=<MaxBackward1>)

OK, I see.
The issue is how you set your labels to one:

label[0, 20:32, 80:150, 80:150] = 1

Since you are missing a dimension (label was initialized as torch.zeros(1, 1, 50, 300, 300)), you are trying to index 20:32 in dim1, which will return an empty slice.
Thus label only contains zeros, without the manually assigned target in the volume.
Change the line to:

label[0, 0, 20:32, 80:150, 80:150] = 1

and run it again.
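The silent failure is easy to reproduce on a smaller volume (the spatial sizes are scaled down here purely for illustration): indexing the size-1 channel dimension with 20:32 yields an empty view, and assigning to an empty view raises no error.

```python
import torch

# Smaller stand-in for the [1, 1, 50, 300, 300] label from the question.
label = torch.zeros(1, 1, 50, 30, 30)

# Wrong: after the batch index, 20:32 slices the channel dim (size 1),
# so the selected view is empty and the assignment silently does nothing.
label[0, 20:32, 8:15, 8:15] = 1
print(label[0, 20:32].shape)  # empty along the first dimension
empty_sum = label.sum().item()
print(empty_sum)

# Correct: index batch and channel, then depth, height and width.
label[0, 0, 20:32, 8:15, 8:15] = 1
print(label.sum().item())
```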

Ah right, good catch.

I've made the change, but the network still does not segment this simple cube. I've let the code run for 800 epochs, and output.max() produces: tensor(-2.1355, device='cuda:0', grad_fn=<MaxBackward1>).

Therefore, thresholding at output > 0 again produces an all-zero tensor.

It seems the hyperparameters / architecture are not working for this kind of data.
I could overfit a small example using a lower learning rate:

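import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# small volume with a 2x5x5 foreground block
data = torch.randn(1, 1, 5, 10, 10)
label = torch.zeros_like(data)
label[0, 0, 2:4, 2:7, 2:7] = 1
label.unique()
data = data.to(device)
label = label.to(device)

model = nn.Sequential(
        nn.Conv3d(1, 32, 3, 1, 1),
        nn.ReLU(),
        nn.Conv3d(32, 1, 3, 1, 1),
        nn.ReLU(),
        nn.Conv3d(1, 1, 3, 1, 1, bias=False)
    )

model = model.to(device)

loss_criterion = nn.BCEWithLogitsLoss()

# much lower learning rate than in the question (1e-5 instead of 1e-2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(800):
    optimizer.zero_grad()

    output = model(data)
    loss = loss_criterion(output, label)
    loss.backward()
    optimizer.step()

    print('Epoch {}, loss {}'.format(epoch, loss.item()))

print('training finished')


# the model outputs logits, so apply sigmoid before thresholding at 0.5
pred = torch.sigmoid(output) > 0.5
# voxel accuracy: fraction of voxels where the prediction matches the label
print((pred.float() == label).float().sum() / label.nelement())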

When you run your code, check the bias of your last conv layer: in my runs the activation was apparently “killed” and the model output was a constant value given by the bias.
In my experience this happens if your learning rate is too high and throws the parameters off.

PS: The threshold of 0.5 was of course a bad idea, as I didn’t realize you are using logits, not probabilities. For logits, the equivalent threshold is 0 (or apply torch.sigmoid first, as in the code above).
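One way to look for a “killed” activation in the toy model is to register forward hooks and inspect the activations directly. The sketch below uses a reduced copy of the model from the question (fewer channels, assumed input size); the zero-fraction metric is an assumed diagnostic, not something from the thread. It records the fraction of zeros after each ReLU, prints the last conv bias, and then deliberately zeroes a layer to show the symptom described above: every output voxel equals the last bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layout as the model in the question, with fewer channels for speed.
model = nn.Sequential(
    nn.Conv3d(1, 8, 3, 1, 1),
    nn.ReLU(),
    nn.Conv3d(8, 4, 3, 1, 1),
    nn.ReLU(),
    nn.Conv3d(4, 1, 3, 1, 1),
)

# Fraction of exactly-zero outputs after each ReLU, filled in by forward hooks.
zero_fraction = {}


def make_hook(name):
    def hook(module, inputs, output):
        zero_fraction[name] = (output == 0).float().mean().item()
    return hook


for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

data = torch.randn(1, 1, 5, 10, 10)
with torch.no_grad():
    output = model(data)
healthy_zero_fraction = dict(zero_fraction)
print("zero fraction per ReLU:", healthy_zero_fraction)
print("last conv bias:", model[-1].bias.item())
print("output std:", output.std().item())

# Simulate a dead second ReLU: the last conv then sees an all-zero input,
# so its output collapses to the bias everywhere.
with torch.no_grad():
    model[2].weight.zero_()
    model[2].bias.fill_(-1.0)
    dead_output = model(data)
print("zero fraction after killing layer 2:", zero_fraction["3"])
print("output equals last bias:",
      torch.allclose(dead_output, model[-1].bias.expand_as(dead_output)))
```

In a real run you would watch the same two signals during training: a ReLU whose zero fraction approaches 1.0 and an output whose standard deviation drops to 0.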

Thank you, that seems to be what's happening here. I tried lowering the learning rate, but it did not help.
My code is the same as I posted above, just a simple model with sample data. These little issues are what I want to learn about and iron out before moving on to complex models.

Can you please advise how to check the bias and the activations, and how I would identify this happening in my models? I tried print(list(model.parameters())), but how can I tell from the printed output that the network stopped learning?

Thank you

Can I ask what you did about your argmax problem?

I saw your post here, and it is similar to my problem in this post.
To get a non-zero argmax: if I don’t use ReLU in the model and use a low learning rate, the accuracy of this model is very low.
What can I do about this?
Also, I wrote a model (with 3 conv and norm layers and without any activation function) that gets 99% accuracy, but argmax is still zero!
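A likely explanation for high accuracy together with an all-zero argmax is class imbalance, though this is an inference rather than something confirmed in the thread. The sketch reuses the corrected cube label from the BCE example above and shows that predicting background everywhere already scores about 98.7% voxel accuracy, while the foreground Dice is 0, so an overlap metric is more telling here:

```python
import torch

# The label volume from the BCE example earlier in the thread (fixed indexing).
label = torch.zeros(1, 1, 50, 300, 300)
label[0, 0, 20:32, 80:150, 80:150] = 1

# A degenerate model that predicts background everywhere.
pred = torch.zeros_like(label)

accuracy = (pred == label).float().mean().item()
foreground_fraction = label.mean().item()

# Foreground Dice: 2 * |P and T| / (|P| + |T|); 0 here because P is empty.
dice = (2 * (pred * label).sum() / (pred.sum() + label.sum())).item()

print(f"foreground fraction: {foreground_fraction:.4f}")
print(f"voxel accuracy:      {accuracy:.4f}")
print(f"foreground Dice:     {dice:.4f}")
```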