Hi everyone!
I’m very new to PyTorch and Python, although I know the basics of programming. I’m trying to process some MR images in DICOM format to classify them into two classes. I’ve created a custom dataset class (code below) and I would like to know whether I’m approaching it correctly.
So, my questions are:
How can I improve my code? Is it a good idea to build a list of (image pixel_data, label) entries, or should I use another approach? In __getitem__, is it good practice to return a dictionary?
import os
import torch
import pydicom
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
"""
# Train images: ImageData/train/0; ImageData/train/1
# Test images: ImageData/test/0; ImageData/test/1
# Classes: "not_operable": 0; "operable":1
#
# set_file_matrix():
# Scan the sub-folders of root/image_dir:
# Create a list with one "(file_name, label)" entry per item in root/image_dir
# (pixel data is read later, in __getitem__)
# Takes label from sub-folder name "0" or "1"
"""
class DicomDataset(Dataset):
    """Two-class dataset of DICOM images stored in per-label sub-folders.

    ``root/image_dir`` is expected to contain one sub-folder per class whose
    name is the numeric label (e.g. ``"0"`` and ``"1"``).  Only file names and
    labels are indexed when the dataset is created; pixel data is read from
    disk on demand in ``__getitem__``.

    Args:
        root: Base directory of the image data (e.g. ``ImageData``).
        image_dir: Sub-directory of ``root`` holding the label folders
            (e.g. ``train`` or ``test``).
        transform: Optional callable applied to each image tensor.  When
            omitted, the image is resized to 128 and center-cropped to
            128x128.
    """

    def __init__(self, root, image_dir, transform=None):
        super().__init__()
        self.image_dir = os.path.join(root, image_dir)
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(128),
            ])
        self.transform = transform
        self.data = self.set_file_matrix()

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.data)

    def set_file_matrix(self):
        """Index every file found in the label sub-folders of ``image_dir``.

        Returns:
            A list of ``[file_name, label]`` entries, where ``label`` is the
            name of the sub-folder containing the file.  Entries are ordered
            by label and then by file name.
        """
        root = self.image_dir
        files = []
        # Sort both levels: os.listdir() order is arbitrary, and a stable
        # index -> sample mapping keeps runs reproducible.
        for label in sorted(os.listdir(root)):
            label_dir = os.path.join(root, label)
            if not os.path.isdir(label_dir):
                continue
            for name in sorted(os.listdir(label_dir)):
                # Skip nested directories; only regular files can be images.
                if os.path.isfile(os.path.join(label_dir, name)):
                    files.append([name, label])
        return files

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with the image tensor under ``'image'`` and the label,
            converted to ``float``, under ``'label'``.

        Raises:
            ValueError: If the sub-folder name cannot be converted to a float.
        """
        file_name, label = self.data[index]
        dicom = pydicom.dcmread(os.path.join(self.image_dir, label, file_name))
        # Prepend a channel axis; assumes pixel_array is a single 2-D slice
        # -- TODO confirm for multi-frame files.
        image = np.array(dicom.pixel_array, dtype=np.float32)[np.newaxis]
        image = torch.from_numpy(image)
        if self.transform is not None:
            image = self.transform(image)
        return {'image': image, 'label': float(label)}
Thanks in advance!