How to improve custom Dataset class for reading DICOM images?

Hi everyone!
I’m very new to PyTorch and Python, although I know the basics of programming. I’m trying to classify some MR images in DICOM format into two classes. I’ve created a custom Dataset class (code below) and I would like to know if I’m approaching it correctly.

So, my questions are:

How can I improve my code?

Is it a good idea to build a list of (image, label) entries, or should I use another approach?

In __getitem__, is it good practice to return a dictionary?

```python
import os
import torch
import pydicom
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
"""
# Train images: ImageData/train/0; ImageData/train/1
# Test images: ImageData/test/0; ImageData/test/1
# Classes: "not_operable": 0; "operable":1
#
# set_file_matrix():
# Count total items in sub-folders of root/image_dir:
# Create a list with all items from root/image_dir "(pixel_array, label)"
# Takes label from sub-folder name "0" or "1"
"""
class DicomDataset(Dataset):
    def __init__(self, root, image_dir):
        self.image_dir = os.path.join(root, image_dir)  # ImageData/train or ImageData/test
        self.data = self.set_file_matrix()
        self.transform = transforms.Compose([
            transforms.Resize(128),
            transforms.CenterCrop(128)
        ])

    def __len__(self):
        return len(self.data)

    def set_file_matrix(self):
        # count elements
        total = 0
        root = self.image_dir
        folders = ([name for name in os.listdir(root)
                    if os.path.isdir(os.path.join(root, name))])
        for folder in folders:
            new_path = os.path.join(root, folder)
            contents = len([name for name in os.listdir(new_path) if os.path.isfile(os.path.join(new_path, name))])
            total += contents

        # create list(img_name, label)
        files = []
        labels = ([name for name in os.listdir(root)
                   if os.path.isdir(os.path.join(root, name))])
        for label in labels:
            new_path = os.path.join(root, label)
            file_list = os.listdir(new_path)
            for file in file_list:
                files.append([file, label])

        return files

    def __getitem__(self, index):
        image_file = pydicom.dcmread(os.path.join(self.image_dir, self.data[index][1], self.data[index][0]))
        image = np.array(image_file.pixel_array, dtype=np.float32)[np.newaxis]  # Add channel dimension
        image = torch.from_numpy(image)
        label = float(self.data[index][1])
        if self.transform:
            image = self.transform(image)

        return {'image': image, 'label': label}
```

Thanks in advance!

Your code looks generally fine.
This line of code looks a bit strange:

```python
label = float(self.data[index][1])
```

as I would assume self.data contains (part of) a path.

Thanks @ptrblck !

About that strange line, maybe there is something I’m not doing right.

This is the way I load the dataset:

```python
train_set = DicomDataset(ROOT_PATH, 'train')
test_set = DicomDataset(ROOT_PATH, 'test')
train_set_loader = torch.utils.data.DataLoader(train_set, batch_size=5, shuffle=True)
test_set_loader = torch.utils.data.DataLoader(test_set, batch_size=5, shuffle=True)
```

And this is the way I iterate over it in my model:

```python
# more code before
n_total_steps = len(self.train_set_loader)
for epoch in range(self.num_epochs):
    i = 0
    for item in enumerate(self.train_set_loader):
        images = item[1]["image"].to(device)
        labels = item[1]["label"].to(device)

# some more code after
```

If I don’t use float(), I get one of these errors: AttributeError: 'list' object has no attribute 'to' or AttributeError: 'str' object has no attribute 'to'.

Is it good practice to use float(), or should I use another approach to read/save/load/iterate over the labels?

You shouldn’t be able to directly convert a non-numeric string to a floating point number:

```python
float('a')
# ValueError: could not convert string to float: 'a'
```

so I’m not sure what exactly is stored in self.data[index][1].
Based on the error messages, it seems it might be a list or a str.

Generally, you would have to make sure the targets are created as valid tensors (e.g. as class indices).
Depending on how you are loading the dataset, you might need to create the class indices from the folder names or from some other mapping.
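A minimal sketch of such a mapping, assuming the folder layout from the question (one sub-folder per class under ImageData/train and ImageData/test). The helper name `build_samples` is hypothetical. It sorts the folder names so each folder always gets the same index (similar in spirit to how torchvision’s `ImageFolder` builds `class_to_idx`), and it stores full paths plus integer class indices instead of (file name, label string) pairs:

```python
import os
from typing import Dict, List, Tuple


def build_samples(image_dir: str) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Return (file_path, class_index) pairs plus the folder-to-index mapping.

    Hypothetical helper: assumes every sub-folder of image_dir is one class
    and every regular file inside it is one image.
    """
    # Sort folder names so the same folder always maps to the same index.
    class_names = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isdir(os.path.join(image_dir, name))
    )
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    samples = []
    for name in class_names:
        folder = os.path.join(image_dir, name)
        for file_name in sorted(os.listdir(folder)):
            path = os.path.join(folder, file_name)
            if os.path.isfile(path):  # skip nested folders
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx
```

With this, the label in `__getitem__` is already a Python `int`. The default `DataLoader` collate function stacks ints (and floats) into a tensor, whereas strings are left as a plain list, which is why calling `.to(device)` on string labels fails.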

I haven’t done much with DICOMs, but here is a Dataset I made to read NIfTI files, although it was intended for segmentation, not classification. I haven’t tested it properly: when I tried it, my segmentation masks contained a mistake, and I’m reasonably sure that was why my results were bad.

Depending on how my data was saved, I’ve used both lists and dictionaries when assembling my dataset, and neither has caused me problems. Generally, though, training runs much, much faster and the datasets are a bit simpler if you do most of the preprocessing beforehand and save the data as torch tensors (contrast the dataset above with this one). Doubly so if your files need normalization applied.
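A rough sketch of the preprocess-once idea, assuming the pixel arrays have already been read (e.g. with `pydicom.dcmread(path).pixel_array`). It z-score normalizes each image once and caches it with NumPy’s `np.save`, which stands in here for saving torch tensors with `torch.save`. The helper names and the `.npy` cache layout are hypothetical:

```python
import os

import numpy as np


def normalize(pixels: np.ndarray) -> np.ndarray:
    """Z-score normalize one image to zero mean and unit variance (float32)."""
    pixels = pixels.astype(np.float32)
    std = pixels.std()
    # Guard against constant images, which would otherwise divide by zero.
    if std == 0:
        return pixels - pixels.mean()
    return (pixels - pixels.mean()) / std


def cache_image(pixels: np.ndarray, out_dir: str, name: str) -> str:
    """Normalize once and save as .npy so training only has to load it."""
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, name + ".npy")
    # Add the channel dimension here too, like [np.newaxis] in the question.
    np.save(out_path, normalize(pixels)[np.newaxis])
    return out_path
```

The Dataset’s `__getitem__` would then only need something like `torch.from_numpy(np.load(path))`, moving the DICOM decoding and normalization out of the training loop.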

self.data[index][1] stores the labels as '0' or '1', which I think is why float() gives no error. About the recommendation, I don’t think I understand what a “class index” is. Is it a specific class, like class class_name(): ..., or something else?

Thanks! I’ll take a look at your code.

Ah, OK, that makes sense then and should be the right approach.
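For reference, a “class index” here is just an integer in range(num_classes) that identifies the class, not a Python class. `nn.CrossEntropyLoss` accepts targets in this form (int64, shape `(N,)`), while `nn.BCEWithLogitsLoss` expects float targets with the same shape as the model output, so `float(label)` fits the latter. The sketch below uses NumPy only to stay dependency-light, and the variable names are illustrative. It shows how the string folder labels become class indices and how those relate to a one-hot encoding:

```python
import numpy as np

# Class names from the question's docstring; list position = class index.
CLASS_NAMES = ["not_operable", "operable"]

# Folder labels as stored in self.data[index][1] (strings, not numbers).
folder_labels = ["0", "1", "1", "0"]

# Class indices: plain integers, the target format CrossEntropyLoss accepts.
class_indices = np.array([int(label) for label in folder_labels], dtype=np.int64)

# Equivalent one-hot encoding: row i has a 1 in column class_indices[i].
one_hot = np.eye(len(CLASS_NAMES), dtype=np.float32)[class_indices]

for idx in class_indices:
    print(idx, CLASS_NAMES[idx])
```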