I’m quite new to PyTorch, so I might say things that are not completely correct.
Out of curiosity, I’m trying to implement multiclass segmentation using the U-Net code found here: https://github.com/usuyama/pytorch-unet. Since I have my own dataset of images, I’ve modified the class “SimDataset” to load my own images and masks in the following way.
The images are 224x224 RGB PNGs, and the masks are indexed images with four classes (indices 1, 2, 3 are the features, 0 is the background).
import PIL
from PIL import Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    def __init__(self, image_paths, mask_paths, count, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        # note: count and transform are currently not used inside the class

    def transforms(self, image, mask):
        #img = img.resize((wsize, baseheight), PIL.Image.ANTIALIAS)
        #image = transforms.Resize(size=(64, 64))(image)
        #mask = transforms.Resize(size=(64, 64))(mask)
        # resize with NEAREST so the mask index values are not interpolated
        image = image.resize((64, 64), PIL.Image.NEAREST)
        mask = mask.resize((64, 64), PIL.Image.NEAREST)
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return [image, mask]

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.mask_paths[index])
        x, y = self.transforms(image, mask)
        return [x, y]

    def __len__(self):
        return len(self.image_paths)

# note: trans is passed in but never applied, since SimDataset ignores transform
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip()
])

train_set = SimDataset(train_paths, train_masks_paths, 100, transform=trans)
# note: val_set currently uses the same paths as train_set
val_set = SimDataset(train_paths, train_masks_paths, 11, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 1

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets.keys()
}
dataset_sizes
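For reference, this is roughly how I check which class indices are actually present in one of my masks before any transform is applied (mask_paths[0] is just an example path from my dataset):

import numpy as np
from PIL import Image

# open one indexed mask and inspect the raw index values (no ToTensor applied)
mask = Image.open(mask_paths[0])
print(np.unique(np.array(mask)))  # I expect something like [0 1 2 3]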
I’ve also been reading here that passing the mask as a tensor with ToTensor() is not a good habit, since that transform normalises the class index values (and indeed it does), leading to instabilities. I would like to ask how to pass the masks correctly during training (using, if possible, the code snippet above).
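From what I’ve read, my current guess is that the mask should stay as integer class indices rather than going through ToTensor(). Something like this variation of the transforms method inside SimDataset (just my attempt, not tested):

import numpy as np
import torch
import torchvision.transforms.functional as TF
import PIL

def transforms(self, image, mask):
    image = image.resize((64, 64), PIL.Image.NEAREST)
    mask = mask.resize((64, 64), PIL.Image.NEAREST)
    image = TF.to_tensor(image)                     # image is scaled to [0, 1]
    mask = torch.from_numpy(np.array(mask)).long()  # mask keeps its raw indices 0..3
    return [image, mask]

Is this the right direction, or is there a better way to do it?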
Also, since I want 3 classes to be recognised during training, should the number of classes be nb_class = 3 (i.e. n_total_index - 1, meaning indices 0, 1, 2, 3 but without the background)?
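For context, my current understanding (which may well be wrong) is that the background also counts as a class for something like CrossEntropyLoss, so the model would have 4 output channels and the target mask would hold indices 0..3, roughly like this:

import torch
import torch.nn as nn

num_classes = 4  # 3 feature classes + background (my assumption)
criterion = nn.CrossEntropyLoss()

# pred: [batch, num_classes, H, W] raw logits from the network
# target: [batch, H, W] LongTensor with values 0..3
pred = torch.randn(1, num_classes, 64, 64)
target = torch.randint(0, num_classes, (1, 64, 64))
loss = criterion(pred, target)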
Thanks in advance!