Error in Python’s multiprocessing library

This is a UNet architecture for a custom semantic segmentation dataset, based on this paper. I get the following error when num_workers = 2:

```
  File "C:\Users\Neda\Anaconda3\lib\multiprocessing\spawn.py", line 172, in get_preparation_data
    main_mod_name = getattr(main_module.__spec__, "name", None)

AttributeError: module '__main__' has no attribute '__spec__'
```

With num_workers = 0, I get this error instead:

```
raise TypeError((error_msg.format(type(batch[0]))))

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.TiffImagePlugin.TiffImageFile'>
```

The output of print(device) is cuda:0. I have no idea what I am doing wrong. Thank you in advance.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset  # For custom data-sets
import torchvision.transforms as transforms
from PIL import Image
import glob


print(torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# get all the image and mask path and number of images
folder_data = glob.glob("D:\\Neda\\Pytorch\\U-net\\BMMCdata\\data\\*.tif")
folder_mask = glob.glob("D:\\Neda\\Pytorch\\U-net\\BMMCmasks\\masks\\*.tif")

# split these path using a certain percentage
len_data = len(folder_data)
print(len_data)
train_size = 0.6

train_image_paths = folder_data[:int(len_data*train_size)]
# print(train_image_paths) # output is 25 image for train
test_image_paths = folder_data[int(len_data*train_size):]
#print(test_image_paths) # output is 18 image for test

train_mask_paths = folder_mask[:int(len_data*train_size)]
test_mask_paths = folder_mask[int(len_data*train_size):]


class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths):   # initial logic happens like transform

        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transforms = transforms.ToTensor()

    def __getitem__(self, index):

        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        t_image = self.transforms(image)
        return t_image, mask

    def __len__(self):  # return count of sample we have

        return len(self.image_paths)


train_dataset = CustomDataset(train_image_paths, train_mask_paths)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

test_dataset = CustomDataset(test_image_paths, test_mask_paths)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)


class ConvRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ConvRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                             padding=padding, stride=stride)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        #channels, height, width = in_shape

        self.down1 = nn.Sequential(
            ConvRelu(1, 64, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu(64, 64, kernel_size=(3, 3), stride=1, padding=0)
            )
        self.maxPool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.down2 = nn.Sequential(
            ConvRelu(64, 128, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu(128, 128, kernel_size=(3, 3), stride=1, padding=0)
            )
        self.maxPool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.down3 = nn.Sequential(
            ConvRelu(128, 256, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu(256, 256, kernel_size=(3, 3), stride=1, padding=0)
        )
        self.maxPool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.down4 = nn.Sequential(
            ConvRelu(256, 512, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu(512, 512, kernel_size=(3, 3), stride=1, padding=0)
        )

        self.maxPool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.center = nn.Sequential(
            ConvRelu(512, 1024, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu(1024, 1024, kernel_size=(3, 3), stride=1, padding=0)
        )
        self.upSample1 = nn.ConvTranspose2d(1024, 1024, 2, stride=2)

        self.up1 = nn.Sequential(
            ConvRelu(1024, 512, kernel_size=(2, 2), stride=1, padding=0),
            ConvRelu(512, 512, kernel_size=(2, 2), stride=1, padding=0)
        )

        self.upSample2 = nn.ConvTranspose2d(512, 512, 2, stride=2)

        self.up2 = nn.Sequential(
            ConvRelu(512, 256, kernel_size=(2, 2), stride=1, padding=0),
            ConvRelu(256, 256, kernel_size=(2, 2), stride=1, padding=0)
        )

        self.upSample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)

        self.up3 = nn.Sequential(
            ConvRelu(256, 128, kernel_size=(2, 2), stride=1, padding=0),
            ConvRelu(128, 128, kernel_size=(2, 2), stride=1, padding=0)
        )

        self.upSample4 = nn.ConvTranspose2d(128, 128, 2, stride=2)

        self.up4 = nn.Sequential(
            ConvRelu(128, 64, kernel_size=(2, 2), stride=1, padding=0),
        )
        # 1x1 convolution at the last layer
        self.output_seg_map = nn.Conv2d(64, 2, kernel_size=(1, 1), padding=0, stride=1)

    def crop_concat(self, upsampled, bypass, crop=False):
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            # -c is amount of pad which will add on each side for all dimension
            bypass = F.pad(bypass, (-c, -c, -c, -c))  # (padLeft, padRight, padTop, padBottom)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        x = self.down1(x)
        out_down1 = x

        x = self.down2(x)
        out_down2 = x
        x = self.maxPool2(x)

        x = self.down3(x)
        out_down3 = x
        x = self.maxPool3(x)

        x = self.down4(x)
        out_down4 = x
        x = self.maxPool4(x)

        x = self.center(x)

        x = self.upSample1(x)
        x = self.up1(x)
        self.crop_concat(x, out_down4)

        x = self.upSample2(x)
        x = self.up2(x)
        self.crop_concat(x, out_down3)

        x = self.upSample3(x)
        x = self.up3(x)
        self.crop_concat(x, out_down2)

        x = self.upSample4(x)
        x = self.up4(x)
        self.crop_concat(x, out_down1)

        out = self.output_seg_map(x)

        return F.log_softmax(out, dim=1)  # applies log-softmax over the class channel


net = UNet()
net = net.to(device)

criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.99)

def main():
    for epoch in range(2):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # get the inputs
            t_image, mask = data
            t_image, mask = t_image.to(device), mask.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(t_image)
            loss = criterion(outputs, mask)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print('Finished Training')

if __name__=='__main__':
    main()
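A hedged note on the num_workers = 2 error, which the replies below do not address directly: on Windows, DataLoader worker processes are started with the spawn method, which re-imports the main script in every worker. Running the script inside Spyder’s IPython console is a commonly reported cause of the `__spec__` AttributeError, and running it from a terminal usually avoids it. The minimal sketch below keeps DataLoader construction and iteration under the `if __name__ == '__main__':` guard. `ToyDataset` and `build_loader` are hypothetical stand-ins for `CustomDataset` and the loader setup in the question.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for CustomDataset: returns tensors only."""

    def __init__(self, n=8):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        image = torch.rand(1, 32, 32)                           # fake 1-channel image
        mask = torch.randint(0, 3, (32, 32), dtype=torch.long)  # fake class indices
        return image, mask


def build_loader(num_workers):
    # Module-level code runs again in every spawned worker; code that is only
    # reached through the __main__ guard below does not.
    return DataLoader(ToyDataset(), batch_size=4, shuffle=True,
                      num_workers=num_workers)


def main(num_workers=2):
    shapes = []
    for images, masks in build_loader(num_workers):
        shapes.append((tuple(images.shape), tuple(masks.shape)))
    print('Finished Training')
    return shapes


if __name__ == '__main__':
    main()
```

Moving the dataset, loader, model, and optimizer creation from the question’s module scope into `main()` follows the same pattern.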

The error is most likely thrown because your mask is still a PIL.Image instead of a tensor.
I guess using transforms.ToTensor() wouldn’t be the best idea in this case, as you are dealing with a semantic segmentation task, i.e. your mask should contain class indices for all pixels.

Does your current mask use a specific color code, so that we could use it to create a mask tensor with class indices?

@ptrblck When I close Spyder and run the code again, I get this error: raise TypeError((error_msg.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.TiffImagePlugin.TiffImageFile'>
I guess it is because the mask wasn’t transformed to a tensor.

I checked the masks. They mainly contain the values 85 and 170, and some of them also contain 255. The mask looks like this. How can I create a mask tensor with class indices?

Yeah, I think the mask is the problem as described earlier.
Could you upload one sample somewhere else? I would need an account to download the image from your university.

I guess you are currently loading a “normal” image using a specific color code for certain classes, so that we would need a mapping to decode them to class indices.

Which dataset are you using? Did you create it yourself?

@ptrblck Yes, at the moment I am loading a normal image. I am currently using the TEM dataset in this link: I downloaded the BMMC dataset and the BMMC segmentations as masks. For now, I am using this dataset to understand how to solve a segmentation task. My own data will probably have two classes, 0 and 1, where 0 is the background and 1 is the left ventricle.

My dataset will be different: cardiac ultrasound images, with masks of the left ventricle that need to be annotated by a cardiologist.

Thanks for the link!
I’ve downloaded the BMMC dataset with the masks and it seems three different color codes are used for the masks as you already explained.
This code should load masks and convert these values to class indices:

import numpy as np


class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths):   # initial logic happens like transform
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transforms = transforms.ToTensor()
        self.mapping = {
            85: 0,
            170: 1,
            255: 2
        }
    
    def mask_to_class(self, mask):
        for k in self.mapping:
            mask[mask==k] = self.mapping[k]
        return mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        t_image = self.transforms(image)
        mask = torch.from_numpy(np.array(mask)).long()  # class-index targets must be long for NLLLoss/CrossEntropyLoss
        mask = self.mask_to_class(mask)
        return t_image, mask

    def __len__(self):  # return count of sample we have
        return len(self.image_paths)


train_dataset = CustomDataset(train_image_paths, train_mask_paths)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

for data, target in train_loader:
    print(torch.unique(target))

You can of course change the mapping, e.g. if the pixel value 85 should be mapped to another class index.
Just make sure your classes start with 0 and end with num_classes-1 in case you would like to use a classification criterion like nn.CrossEntropyLoss.
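As an alternative to the loop in mask_to_class, a lookup table maps every possible 8-bit value in a single indexing step and makes unexpected pixel values easy to catch. The sketch below assumes 8-bit masks and the 85/170/255 mapping from above. The -1 sentinel, the `masks_to_classes` helper, and the random logits are illustrative choices. It also shows the shapes and dtype nn.CrossEntropyLoss expects for segmentation: logits of shape [N, C, H, W] and long targets of shape [N, H, W] with values in 0..C-1. nn.CrossEntropyLoss applies log-softmax internally, so it matches the log_softmax + nn.NLLLoss combination in the question.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 3
MAPPING = {85: 0, 170: 1, 255: 2}

# Lookup table over all 256 possible uint8 values; unmapped values become -1.
LUT = torch.full((256,), -1, dtype=torch.long)
for pixel_value, class_index in MAPPING.items():
    LUT[pixel_value] = class_index


def masks_to_classes(mask):
    """Map a uint8 mask tensor of pixel values to long class indices."""
    classes = LUT[mask.long()]  # advanced indexing keeps the mask's shape
    if (classes < 0).any():
        raise ValueError('mask contains pixel values missing from MAPPING')
    return classes


# Made-up batch: 2 masks of size 4x4 using only the known color codes.
codes = torch.tensor([85, 170, 255], dtype=torch.uint8)
mask = codes[torch.randint(0, 3, (2, 4, 4))]
target = masks_to_classes(mask)  # shape [2, 4, 4], dtype torch.long

logits = torch.randn(2, NUM_CLASSES, 4, 4)  # stand-in for the model output
loss = nn.CrossEntropyLoss()(logits, target)
print(target.min().item(), target.max().item(), loss.item())
```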

@ptrblck Thanks a lot for the effort you put into this forum. I really appreciate it. At the moment I am getting CUDA error: out of memory. I will sort out the memory issue and update you.

@ptrblck My dedicated video memory is 8192 MB. Shouldn’t that be enough to train this UNet model? The images are quite large, though (1024x1024).

What’s the batch_size? You may need to use a smaller batch_size.
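A rough back-of-the-envelope estimate helps explain why 8 GB runs out quickly at this resolution. It is an approximation that ignores weight gradients, optimizer state, cuDNN workspaces, and the slight shrinking caused by padding=0. A single float32 activation of shape [4, 64, 1024, 1024] already takes 1 GiB, and the first down block alone keeps several tensors of about that size for the backward pass. `activation_bytes` is a hypothetical helper.

```python
def activation_bytes(batch_size, channels, height, width, bytes_per_element=4):
    """Memory needed to store one float32 activation tensor of the given shape."""
    return batch_size * channels * height * width * bytes_per_element


GIB = 2 ** 30
# Output of a single 64-channel conv at the full 1024x1024 input resolution.
first_layer = activation_bytes(batch_size=4, channels=64, height=1024, width=1024)
print(first_layer / GIB)  # 1.0
```

Activation memory scales linearly with batch_size, so batch_size=1 brings this tensor down to 0.25 GiB. The deeper layers have more channels at lower resolution, so they add up as well.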

@vmirly1 batch_size is 4. I tried smaller values and still got the out-of-memory error.

Oh, I see!

I have never worked with UNet. Do you really need to work with such large images? At the very least, you could downsample the input images for debugging purposes.
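Following the downsampling suggestion, here is a sketch that assumes images are float tensors of shape [N, C, H, W] and masks are long class-index tensors of shape [N, H, W]. The helper name `downsample_pair` and the 256x256 target size are arbitrary illustrative choices. The key point is that masks should be resized with nearest-neighbour interpolation, so that no new in-between class values appear.

```python
import torch
import torch.nn.functional as F


def downsample_pair(image, mask, size):
    """Resize an image batch and its mask batch for quick debugging runs.

    image: float tensor [N, C, H, W]; mask: long tensor [N, H, W] of class indices.
    """
    small_image = F.interpolate(image, size=size, mode='bilinear', align_corners=False)
    # Nearest-neighbour keeps every output pixel equal to an existing class index.
    # The mask is cast to float for interpolate and back to long afterwards.
    small_mask = F.interpolate(mask[:, None].float(), size=size, mode='nearest')
    return small_image, small_mask[:, 0].long()


image = torch.rand(2, 1, 1024, 1024)
mask = torch.randint(0, 3, (2, 1024, 1024))
small_image, small_mask = downsample_pair(image, mask, size=(256, 256))
print(small_image.shape, small_mask.shape, torch.unique(small_mask))
```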

1024 is pretty large for this kind of model.
In the original paper the authors also used a sliding window approach to get their model working.
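The paper uses an overlap-tile strategy. The sketch below is a simpler variant that uses non-overlapping tiles, which may produce visible seams at tile borders. It runs a model on 512x512 tiles and stitches the predictions back together, so peak memory depends on the tile size rather than on the full 1024x1024 image. `predict_tiled` is a hypothetical helper, and it assumes two things: the model returns an output with the same spatial size as its input, and the image size is divisible by the tile size.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_tiled(model, image, tile=512):
    """Run `model` on non-overlapping tiles of `image` ([N, C, H, W]) and stitch them."""
    n, _, h, w = image.shape
    assert h % tile == 0 and w % tile == 0, 'image size must be divisible by tile'
    output = None
    for top in range(0, h, tile):
        for left in range(0, w, tile):
            patch = image[:, :, top:top + tile, left:left + tile]
            pred = model(patch)  # expected shape [N, num_classes, tile, tile]
            if output is None:
                output = pred.new_zeros((n, pred.shape[1], h, w))
            output[:, :, top:top + tile, left:left + tile] = pred
    return output


# Stand-in model that preserves spatial size (3x3 conv with padding=1).
model = nn.Conv2d(1, 3, kernel_size=3, padding=1)
image = torch.rand(1, 1, 1024, 1024)
print(predict_tiled(model, image).shape)  # torch.Size([1, 3, 1024, 1024])
```

For training, the same idea corresponds to sampling random 512x512 crops from each image/mask pair.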

I tried to debug the model as @vmirly1 suggested; however, there seem to be a few issues in your code.
If I resize the images to 512 or 256, I get a size mismatch error even if I use crop=True in the crop_concat calls. I fixed it with an ugly hack to get it working.
However, I also realized that you never actually use the concatenated output of the crop_concat calls.
If you assign the result back to x, you’ll get another error, since the number of input channels won’t match your conv layers.
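To illustrate the last point: once the concatenation is assigned back to x, the next conv layer must accept up_channels + skip_channels input channels. The minimal sketch below uses made-up channel counts and sizes, and it center-crops the skip tensor by slicing rather than by the negative F.pad in the question’s crop_concat.

```python
import torch
import torch.nn as nn


def crop_concat(upsampled, bypass):
    """Center-crop `bypass` to the spatial size of `upsampled`, then concatenate on channels."""
    dh = bypass.shape[2] - upsampled.shape[2]
    dw = bypass.shape[3] - upsampled.shape[3]
    top, left = dh // 2, dw // 2
    bypass = bypass[:, :, top:top + upsampled.shape[2], left:left + upsampled.shape[3]]
    return torch.cat((upsampled, bypass), dim=1)


up_channels, skip_channels, out_channels = 128, 128, 128
# The conv after the concatenation sees both sets of channels.
conv = nn.Sequential(
    nn.Conv2d(up_channels + skip_channels, out_channels, kernel_size=3),
    nn.ReLU(),
)

x = torch.rand(1, up_channels, 56, 56)       # output of an upsampling step
skip = torch.rand(1, skip_channels, 64, 64)  # stored output of a down block
x = crop_concat(x, skip)                     # assign the result back to x
print(x.shape)  # torch.Size([1, 256, 56, 56])
x = conv(x)
print(x.shape)  # torch.Size([1, 128, 54, 54])
```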

If I use my implementation as a drop-in replacement for your model, I can get the model working with a batch size of 1 for the original image shapes, using approximately 7 GB of GPU memory.
Note that the implementations differ a bit; e.g., your model seems to have an additional layer in the bottleneck, which will most likely explain the difference in memory usage.

@ptrblck Thank you for taking the time to explain the issues in my code. Does your UNet model work for images with 1 channel? I think your dataset was RGB when you wrote this model. What is in_channels_skip in your model?

Yes, you can pass the number of input channels using in_channels.
This should work for your dataset:

model = UNet(in_channels=1,
             out_channels=64,
             n_class=3,
             kernel_size=3,
             padding=1,
             stride=1)

in_channels_skip is the number of channels from the block in the downward path which should be concatenated with the other input to the current up block.

So let’s take up3 as an example.
This block takes an activation from down3 as its input (in_channels = 8 * out_channels) and additionally concatenates it with the output of down2 (in_channels_skip = 4 * out_channels).
So you’ll have 12 * out_channels channels for the conv layer in up3, which is why it’s defined this way:

self.conv_block = BaseConv(
    in_channels=in_channels + in_channels_skip,
...
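A self-contained sketch of an up block built around in_channels_skip, following the description above. This is not the implementation referenced in this post, which is not shown here. BaseConv is replaced by a hypothetical two-layer Conv2d + ReLU stack, and the channel counts assume out_channels = 64, so up3 receives 8 * 64 = 512 channels from below and 4 * 64 = 256 channels from down2. The output width of 4 * out_channels is also an assumption.

```python
import torch
import torch.nn as nn


class UpBlock(nn.Module):
    """Hypothetical up block: upsample, concatenate the skip tensor, then convolve."""

    def __init__(self, in_channels, in_channels_skip, out_channels):
        super().__init__()
        # Doubles H and W; keeps the channel count at in_channels.
        self.conv_trans = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        # The first conv sees the upsampled channels plus the skip channels.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels + in_channels_skip, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.conv_trans(x)
        x = torch.cat((x, skip), dim=1)
        return self.conv_block(x)


out_channels = 64
up3 = UpBlock(in_channels=8 * out_channels, in_channels_skip=4 * out_channels,
              out_channels=4 * out_channels)
x = torch.rand(1, 8 * out_channels, 16, 16)     # activation coming from below
skip = torch.rand(1, 4 * out_channels, 32, 32)  # activation from down2
print(up3(x, skip).shape)  # torch.Size([1, 256, 32, 32])
```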

Does this example make it clearer, or do you need more information?

@ptrblck Thanks a lot for the example. Before I make sure I understand this, I have two more questions. First, what is the second in_channels in self.conv_trans1?

class UpConv(nn.Module):
    def __init__(self, in_channels, in_channels_skip, out_channels, kernel_size, padding, stride):
        super(UpConv, self).__init__()
        self.conv_trans1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, padding=0, stride=2)

Second, why did you use 2 * out_channels in the self.down1 layer of the UNet class?
Sorry for asking so many questions. I am following your model before I go back to debugging mine.

The transposed conv layer just doubles the spatial size without changing the number of channels, so that is why I’m passing in_channels twice. I realize that the naming of my variables is really bad, sorry about that.
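For reference, the output height of nn.ConvTranspose2d follows the formula given in the PyTorch documentation (the width is analogous). Plugging in kernel_size=2, stride=2, padding=0, dilation=1, output_padding=0 shows why the spatial size exactly doubles. The channel count is set independently by out_channels, so passing in_channels twice leaves it unchanged.

$$
H_{out} = (H_{in} - 1)\cdot\mathrm{stride} - 2\cdot\mathrm{padding} + \mathrm{dilation}\cdot(\mathrm{kernel\_size} - 1) + \mathrm{output\_padding} + 1
$$

$$
H_{out} = (H_{in} - 1)\cdot 2 - 0 + 1\cdot(2 - 1) + 0 + 1 = 2H_{in}
$$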

I’m basically multiplying the number of output channels by a power of 2 for each down conv, i.e. 2*out_channels, 4*out_channels, 8*out_channels.
You might want to change the multiplicative factor, if you want to increase the number of channels. out_channels is passed to the __init__ of the model.
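Here is a sketch of that channel schedule, assuming a hypothetical encoder in which each down step doubles the channels and halves the spatial size. `make_encoder`, `base`, and `depth` are illustrative names. In this sketch the first step uses `base` itself; starting at 2 * out_channels for down1, as in the model discussed here, amounts to choosing a different `base`. Changing the factor 2 in `base * 2 ** i` changes how fast the width grows.

```python
import torch
import torch.nn as nn


def make_encoder(in_channels=1, base=64, depth=4):
    """Build `depth` down steps whose widths are base, 2*base, 4*base, ..."""
    layers, channels = [], []
    prev = in_channels
    for i in range(depth):
        width = base * 2 ** i
        layers.append(nn.Sequential(
            nn.Conv2d(prev, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        ))
        channels.append(width)
        prev = width
    return nn.Sequential(*layers), channels


encoder, channels = make_encoder()
print(channels)  # [64, 128, 256, 512]
print(encoder(torch.rand(1, 1, 64, 64)).shape)  # torch.Size([1, 512, 4, 4])
```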

@ptrblck Thanks a lot for clarifying this. Do you have a figure for your model like Fig. 1 in this paper? I just wanted to know whether you wrote this model based on a research paper.

No, unfortunately I didn’t create such a nice figure for this implementation.
Yes, I used UNet as the base model and have rewritten it a bit to work with differently shaped inputs and also to output the same spatial size (instead of a cropped version like in the original implementation).

@ptrblck That’s great. Does that mean the output size of your model will be the same as the input size? I can see you used Adam as the optimizer, while the paper I sent earlier used SGD. Is there a specific reason you chose Adam?

Yes, the output size should match the input size, which makes it a bit easier for segmentation tasks, especially if you have valid classes near the borders.
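The size-preserving behaviour mostly comes down to padding. A 3x3 convolution with padding=1 keeps H and W. With padding=0, as in the original paper and in the question’s down blocks, each conv removes 2 pixels, which is why the original UNet output is smaller than its input and the skip connections need cropping. Here is a quick shape check, using 572x572, the input tile size shown in Fig. 1 of the UNet paper:

```python
import torch
import torch.nn as nn

x = torch.rand(1, 1, 572, 572)

same = nn.Conv2d(1, 64, kernel_size=3, padding=1)   # keeps the spatial size
valid = nn.Conv2d(1, 64, kernel_size=3, padding=0)  # shrinks each side by 1 pixel

print(same(x).shape)   # torch.Size([1, 64, 572, 572])
print(valid(x).shape)  # torch.Size([1, 64, 570, 570])
```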

Yeah, just my default optimizer based on my biased decisions for a good baseline. :wink:
