Combination of MRI image and its segmentation

Hi, I want to train a model for survival time prediction from MRI images. Suppose I have both images and labels (segmentations) and I want to make a dataset that contains both the image and the label. What can I do? For example, is there a way to give the segmentation as a new channel? Or is multiplying the image and label tensors a good way? I don't have experience with this, so if someone does, please guide me.
My goal is for my survival prediction model to look only at the tumor region instead of the whole organ.

Hi Maryam!

You can look at it both ways.

For context, as I understand it, you have an MRI image. To be concrete, let’s
imagine that this is a single-channel, gray-scale image where the pixel values
run from 0.0 to 1.0.

Then you also have a label (that comes from some external source and is
not what you are trying to train your network to predict). This label is also an
“image” that is the same size as your image. To be concrete, let’s imagine
that this is also a single-channel, gray-scale image where the pixel values
are 0.0 for pixels in the actual image that are non-tumor pixels and are 1.0
for pixels in the actual image that are tumor pixels.

On the one hand, you could combine your image and label into a single
image by multiplying them together. Now your non-tumor pixels are masked
out. That is to say that what were non-tumor pixels are now “background”
pixels with value 0.0. Your tumor pixels remain whatever they were in your
original image.
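A minimal sketch of this masking approach, assuming img and lbl are same-shaped float tensors laid out as [nChannel, depth, height, width], with lbl holding 0.0 for non-tumor and 1.0 for tumor voxels. The random tensors below are stand-ins, not real MRI data:

```python
import torch

torch.manual_seed(0)

# Stand-ins for a single-channel 3D MRI volume and its binary tumor mask,
# both shaped [nChannel, depth, height, width] = [1, 8, 8, 8].
img = torch.rand(1, 8, 8, 8)        # intensities in [0.0, 1.0)
lbl = torch.zeros(1, 8, 8, 8)
lbl[:, 2:5, 2:5, 2:5] = 1.0         # a small cube of "tumor" voxels

# Elementwise multiplication keeps tumor voxels unchanged and zeroes the rest.
masked = img * lbl

print(masked.shape)  # torch.Size([1, 8, 8, 8]), still a single channel
```

The result still has one channel, so the first convolution of the network keeps in_channels = 1.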

The advantage of this approach is that it is easier for the network in that you
are only giving the network the “important” pixels to process. The network
doesn’t have to “learn” that your label (the zero-one mask image) is telling
it what pixels are important. If you can “tell” your network something, rather
than making your network “learn” it, training the network is easier.

On the other hand, you could combine your image and label into a two-channel
image by stack()ing them together. This has the following potential advantage:
I could well believe that the potential malignancy of a tumor depends not only on
the tumor tissue itself – the tumor pixels in your image – but on the surrounding
tissue as well – your non-tumor pixels. Perhaps it matters for survival whether
the tissue surrounding the tumor is dense or vascularized or whatever.
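A corresponding sketch of the two-channel approach, assuming here that img and lbl are plain [depth, height, width] volumes with no explicit channel dimension yet, so that torch.stack() creates the channel dimension (if they already carry a channel dimension of size 1, cat() along that dimension is the right tool instead). Again, the tensors are random stand-ins:

```python
import torch

torch.manual_seed(0)

# Stand-ins with NO channel dimension: shape [depth, height, width].
img = torch.rand(8, 8, 8)
lbl = torch.zeros(8, 8, 8)
lbl[2:5, 2:5, 2:5] = 1.0

# stack() adds a new leading dimension, which becomes the channel dimension:
# channel 0 is the full image (tumor and surrounding tissue), channel 1 the mask.
two_channel = torch.stack((img, lbl), dim=0)

print(two_channel.shape)  # torch.Size([2, 8, 8, 8])
```

Nothing is zeroed out here, so the network still sees the surrounding tissue, and the mask channel tells it where the tumor is.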

If you mask out (set to zero) the non-tumor pixels, the network won’t have
any information about the non-tumor tissue, so it won’t be able to take the
structure of the non-tumor tissue into account when learning to make its
survival prediction.

So I think that your answer depends on the answer to a biology question.
If the likelihood of survival really only depends on the characteristics of the
tumor tissue, then the approach of using a single-channel input image with
the non-tumor pixels masked out will likely be better (because it makes the
network’s learning task somewhat easier). But if survival depends in a
significant way also on the non-tumor tissue, you will want the approach
of using a two-channel image – the original image stack()ed with the label
“image” – because that’s how you pass in the non-tumor-tissue information
that the network needs to make the best prediction about survival.


K. Frank

Thanks for your complete answer. You are right: when I spoke to clinicians, they told me the tissue surrounding the tumor is important. But I have some technical issues.
I used MONAI to build my dataset because the images are 3D, and handling 3D images is easier with MONAI. I made a dataset that has the image and the ground truth. Then I tried to make a custom PyTorch dataset that gets the image and ground truth and, as you told me, stacks them together.

def __getitem__(self, index):
    if type(index) is not int:
        raise ValueError(f"Need `index` to be `int`. Got {type(index)}.")
    # extract the image pixels from the dictionary of the dataset
    img = self.image_dataset[index]['image']
    lbl = self.image_dataset[index]['label']
    concatim = torch.stack((img, lbl))
    return concatim, (self.time[index], self.event[index])

Then I made DataLoaders from this custom dataset:

import torchtuples as tt
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()

dl_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=collate_fn)
dl_val = DataLoader(dataset_val, batch_size=1, shuffle=True, collate_fn=collate_fn)

In my model, I tried to change the number of input channels of the 3D CNN to 2. This is my model:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Net(nn.Module):
        def __init__(self, out_features):
            super().__init__()  # required before assigning any layers
            # Conv3d arguments: input channels, number of filters, kernel size, stride
            self.conv1 = nn.Conv3d(1, 16, 5, 1)
            self.max_pool = nn.MaxPool3d(2)
            self.conv2 = nn.Conv3d(16, 16, 5, 1)
            self.glob_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
            self.fc1 = nn.Linear(16, 16)
            self.fc2 = nn.Linear(16, out_features)
            self.dropout = nn.Dropout3d(0.25)
            self.dropout1 = nn.Dropout(0.25)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = self.max_pool(x)
            x = self.dropout(x)
            x = F.relu(self.conv2(x))
            x = self.glob_avg_pool(x)
            x = torch.flatten(x, 1)
            x = F.relu(self.fc1(x))
            x = self.dropout1(x)
            x = self.fc2(x)
            return x

When I fit my model, I got this error:
RuntimeError: Expected 5-dimensional input for 5-dimensional weight [16, 1, 5, 5, 5], but got 6-dimensional input of size [1, 2, 1, 128, 128, 128] instead
Here 16 is the number of filters in my first convolution, 5 is the kernel size, and 128 is the width, height, and depth of my 3D images.

Hi Maryam!

It looks like your img and lbl already have an nChannel dimension (with
nChannel = 1), so you will want to use cat() rather than stack().

Your first convolution should be:

            self.conv1 = nn.Conv3d (2, 16, 5, 1)

(In the code you posted, you have input channels of 1, rather than 2.)
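As a quick check, here is a sketch of the posted Net with two changes: super().__init__() is called (nn.Module requires it before layers are assigned), and the number of input channels is a parameter. The in_channels argument and out_features = 1 are illustrative choices, not part of your original code. A small 32-voxel input keeps the example fast; a [1, 2, 128, 128, 128] batch goes through the same way:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same layers as the posted Net; `in_channels` is a hypothetical parameter
    # so the first convolution can accept image + label as two channels.
    def __init__(self, out_features, in_channels=2):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, 16, 5, 1)
        self.max_pool = nn.MaxPool3d(2)
        self.conv2 = nn.Conv3d(16, 16, 5, 1)
        self.glob_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, out_features)
        self.dropout = nn.Dropout3d(0.25)
        self.dropout1 = nn.Dropout(0.25)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max_pool(x)
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.glob_avg_pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        return self.fc2(x)


net = Net(out_features=1)
print(net.conv1.weight.shape)  # torch.Size([16, 2, 5, 5, 5])

# Input layout: [nBatch, nChannel, depth, height, width].
x = torch.randn(1, 2, 32, 32, 32)
print(net(x).shape)            # torch.Size([1, 1])
```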

For your stack() line, use cat() instead. (stack() adds a new dimension to your
tensor, while cat() uses the nChannel dimension that you already have.)
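To see where the 6-dimensional input in your error comes from, here is a sketch comparing the two calls on channel-first tensors shaped like yours (with a smaller spatial size). The unsqueeze(0) only mimics what batching with batch_size = 1 does in the DataLoader; it is not part of the fix:

```python
import torch

# img and lbl as they come out of the MONAI pipeline: [nChannel, D, H, W]
# with nChannel = 1 (16 instead of 128 voxels per side to keep this light).
img = torch.zeros(1, 16, 16, 16)
lbl = torch.zeros(1, 16, 16, 16)

stacked = torch.stack((img, lbl))      # adds a NEW leading dimension
catted = torch.cat((img, lbl), dim=0)  # reuses the existing channel dimension

print(stacked.shape)  # torch.Size([2, 1, 16, 16, 16])
print(catted.shape)   # torch.Size([2, 16, 16, 16])

# Batching then prepends nBatch = 1:
print(stacked.unsqueeze(0).shape)  # torch.Size([1, 2, 1, 16, 16, 16]), 6D: rejected by Conv3d
print(catted.unsqueeze(0).shape)   # torch.Size([1, 2, 16, 16, 16]), 5D: what Conv3d expects
```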


>>> import torch
>>> torch.__version__
>>> nBatch = 1
>>> nChannel = 1
>>> hwd = 128
>>> img = torch.randn (nBatch, nChannel, hwd, hwd, hwd)
>>> lbl = torch.randn (nBatch, nChannel, hwd, hwd, hwd)
>>> img.shape   # has nBatch = 1, nChannel = 1
torch.Size([1, 1, 128, 128, 128])
>>> concatim = torch.cat ((img, lbl), dim = 1)
>>> concatim.shape   # has nBatch = 1, nChannel = 2
torch.Size([1, 2, 128, 128, 128])
>>> conv = torch.nn.Conv3d (in_channels = 2, out_channels = 16, kernel_size = 5)   # takes 2 input channels
>>> conv (concatim).shape
torch.Size([1, 16, 124, 124, 124])


K. Frank

Why did you choose dim=1 in this part?

concatim = torch.cat ((img, lbl), dim = 1)

Also, my custom dataset does not have the batch and channel dimensions yet (in my mind, they are added by the DataLoader), so should I concatenate along dimension 0? I wrote
concatim = torch.cat((img, lbl), 0)
and got this error:
Given groups=1, weight of size [16, 1, 5, 5, 5], expected input[1, 2, 128, 128, 128] to have 1 channels, but got 2 channels instead
So I changed the first conv layer's input channels to 2:

    class Net(nn.Module):
        def __init__(self, out_features):
            #numbers in front of conv3d are input channel,number of filters,kernel size,stride
            self.conv1 = nn.Conv3d(2, 16, 5, 1)

And when I tried to concatenate them along dimension 1, as you wrote, I got this error:
Sizes of tensors must match except in dimension 1. Got 115 and 113 in dimension 2 (The offending index is 1)

Before that, I had resized the images and labels to [128, 128, 128] with MONAI; you can see this in the Resized transform:

from monai.transforms import (AddChanneld, Compose, CropForegroundd, LoadImaged,
                              Orientationd, Resized, ScaleIntensityRanged,
                              Spacingd, ToTensord)

train_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    # add a channel dimension so that MONAI interprets the volumes as channel-first
    AddChanneld(keys=['image', 'label']),
    # resample the volumes to the given voxel spacing (pixdim)
    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2)),
    Orientationd(keys=['image', 'label'], axcodes="RAS"),
    ScaleIntensityRanged(keys='image', a_min=-100, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=['image', 'label'], source_key='image'),
    Resized(keys=['image', 'label'], spatial_size=[128, 128, 128]),
    # ToTensord should be the last transform
    ToTensord(keys=['image', 'label']),
])

Do you know where the problem is?

I should also add that the shapes of img and lbl are:

torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])

Hi Maryam!

I chose dim = 1 because I was assuming the img and lbl already
had an nBatch dimension so that the nChannel dimension would be
in position 1.

Everything you’ve posted suggests that you do have an nChannel
dimension (for img and lbl), even if that dimension has size 1.

To check this, you could print out img.shape and lbl.shape right
before calling concatim = torch.cat((img, lbl), 0). I expect
that the two shapes will be the same and be [1, 128, 128, 128].
If this is the case, then you do have an nChannel = 1 dimension
(and no nBatch dimension yet).

It does make sense that you do not yet have an nBatch dimension and that the
nBatch dimension that appears later is indeed added by the DataLoader.
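Here is a small sketch of that behavior. It uses PyTorch's default collation rather than your torchtuples collate_fn (which should stack along a new leading dimension in the same way, though that is an assumption here), and ToyVolumes is a hypothetical stand-in dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyVolumes(Dataset):
    # Hypothetical stand-in: each item is a [nChannel=2, D, H, W] tensor,
    # i.e. it has a channel dimension but no batch dimension.
    def __init__(self, n_items=4, size=8):
        self.items = [torch.randn(2, size, size, size) for _ in range(n_items)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


ds = ToyVolumes()
print(ds[0].shape)  # torch.Size([2, 8, 8, 8])

# The default collate function stacks items along a new leading (batch) dimension.
dl = DataLoader(ds, batch_size=2, shuffle=False)
batch = next(iter(dl))
print(batch.shape)  # torch.Size([2, 2, 8, 8, 8])
```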

Yes, if you have an nChannel dimension but do not have an nBatch
dimension, you do want concatim = torch.cat((img, lbl), 0),
that is, with dim = 0.
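Put together, a __getitem__ along those lines might look like the sketch below. SurvivalDataset is a hypothetical name; it assumes, as in your code, that image_dataset returns dicts with 'image' and 'label' tensors shaped [1, D, H, W], and that time and event are indexable per patient. Fetching the item once (rather than once per key) also avoids loading and transforming the same files twice:

```python
import torch
from torch.utils.data import Dataset


class SurvivalDataset(Dataset):
    # Hypothetical wrapper around an indexable dataset of
    # {'image': [1, D, H, W], 'label': [1, D, H, W]} dicts.
    def __init__(self, image_dataset, time, event):
        self.image_dataset = image_dataset
        self.time = time
        self.event = event

    def __len__(self):
        return len(self.image_dataset)

    def __getitem__(self, index):
        item = self.image_dataset[index]  # load and transform once
        img, lbl = item['image'], item['label']
        if img.shape != lbl.shape:
            raise ValueError(f"index {index}: image {tuple(img.shape)} "
                             f"vs label {tuple(lbl.shape)}")
        # Concatenate along the existing channel dimension: [2, D, H, W].
        concatim = torch.cat((img, lbl), dim=0)
        return concatim, (self.time[index], self.event[index])


# Usage with small stand-in data (not real MRI volumes):
data = [{'image': torch.rand(1, 8, 8, 8), 'label': torch.zeros(1, 8, 8, 8)}
        for _ in range(3)]
ds = SurvivalDataset(data, time=torch.tensor([5.0, 12.0, 30.0]),
                     event=torch.tensor([1, 0, 1]))
x, (t, e) = ds[0]
print(x.shape)  # torch.Size([2, 8, 8, 8])
```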

Did you change “the first conv input channel to 2” before you got this
error? The weight in the error message you quoted only has one input
channel. Consider:

>>> import torch
>>> torch.__version__
>>> torch.nn.Conv3d (in_channels = 1, out_channels = 16, kernel_size = 5).weight.shape
torch.Size([16, 1, 5, 5, 5])
>>> # versus
>>> torch.nn.Conv3d (in_channels = 2, out_channels = 16, kernel_size = 5).weight.shape
torch.Size([16, 2, 5, 5, 5])

As noted above, call cat() with dim = 0 (not dim = 1).

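Regardless of whether you used dim = 1 or dim = 0 with cat(), the
“Sizes of tensors must match ... Got 115 and 113 in dimension 2” error
means that you’ve changed something else along the way: 115 and 113 are
not 128, so those tensors were not the resized [1, 128, 128, 128] volumes.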
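A small illustration with hypothetical shapes: cat() requires every dimension except the one being concatenated to match, so a mismatch in a spatial dimension is fixed by making the volumes the same size, not by changing dim:

```python
import torch

# Hypothetical channel-first volumes whose spatial sizes disagree
# (115 vs 113 along one axis), e.g. because resizing was skipped for them.
img = torch.zeros(1, 115, 20, 20)
lbl = torch.zeros(1, 113, 20, 20)

failed = False
try:
    torch.cat((img, lbl), dim=0)
except RuntimeError as err:
    failed = True
    print("cat() failed:", err)

# Once both volumes have the same spatial size, the same call succeeds.
lbl_fixed = torch.zeros(1, 115, 20, 20)
print(torch.cat((img, lbl_fixed), dim=0).shape)  # torch.Size([2, 115, 20, 20])
```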

Print out img.shape and lbl.shape right before calling cat().

Is it possible that the shapes of your images (and labels) change from
iteration to iteration so that sometimes the call to cat() works (when the
shapes match), but then the call to cat() fails in a later iteration when
the shapes don’t match?
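One way to check that is to scan the whole dataset once before training. The sketch below uses a hypothetical helper, find_shape_mismatches, and assumes, as in your code, that each item is a dict with 'image' and 'label' tensors:

```python
import torch


def find_shape_mismatches(image_dataset, expected=(1, 128, 128, 128)):
    """Return (index, image shape, label shape) for every item whose image and
    label shapes differ from each other or from `expected`.

    Hypothetical helper: `image_dataset` is assumed to return dicts with
    'image' and 'label' tensors.
    """
    bad = []
    for index in range(len(image_dataset)):
        item = image_dataset[index]
        img_shape = tuple(item['image'].shape)
        lbl_shape = tuple(item['label'].shape)
        if img_shape != lbl_shape or img_shape != tuple(expected):
            bad.append((index, img_shape, lbl_shape))
    return bad


# Usage with small stand-in data: item 1 has a label of the wrong size.
data = [
    {'image': torch.zeros(1, 4, 4, 4), 'label': torch.zeros(1, 4, 4, 4)},
    {'image': torch.zeros(1, 4, 4, 4), 'label': torch.zeros(1, 4, 3, 4)},
]
result = find_shape_mismatches(data, expected=(1, 4, 4, 4))
print(result)  # [(1, (1, 4, 4, 4), (1, 4, 3, 4))]
```

An empty list means every item already has the expected shape, and the problem is more likely elsewhere in the code.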

But in which version of your code do img and lbl have the shapes [1, 128, 128, 128]
you reported? Print them out right before the call to cat() that is giving you the error.


K. Frank
