This is my new dataset class:
class MyDataset(Dataset):
    """Paired image / ground-truth segmentation dataset.

    Lists all files in ``root_dir_img`` and ``root_dir_gt`` and sorts both
    listings so entries pair up by filename order — assumes the two
    directories hold matching files whose names sort identically
    (TODO confirm against the actual data layout).

    Args:
        root_dir_img (str): directory containing the input images.
        root_dir_gt (str): directory containing the ground-truth masks.
        transform (callable, optional): transform applied to the sample
            dict ``{'image': ..., 'mask': ...}``; must return a dict with
            the same keys.
    """

    def __init__(self, root_dir_img, root_dir_gt, transform=None):
        self.root_dir_img = root_dir_img
        self.root_dir_gt = root_dir_gt
        self.transform = transform
        self.img_names = sorted(
            os.path.join(root_dir_img, name) for name in os.listdir(root_dir_img))
        self.gt_names = sorted(
            os.path.join(root_dir_gt, name) for name in os.listdir(root_dir_gt))

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img = Image.open(self.img_names[idx])
        gt = Image.open(self.gt_names[idx])
        sample = {'image': img, 'mask': gt}
        if self.transform:
            sample = self.transform(sample)
        # BUG FIX: the original returned the pre-transform ``img``/``gt``,
        # silently discarding everything the transform pipeline did.
        return sample['image'], sample['mask']
And this is the transform code I’m using (I’m posting only the parts I modified from the stock PyTorch transforms.py — tell me if you need anything more):
class ColorJitter(object):
    """Randomly perturb brightness, contrast, saturation and hue of the
    image in a sample; the mask is returned untouched.

    Args:
        brightness (float): brightness_factor is drawn uniformly from
            [max(0, 1 - brightness), 1 + brightness].
        contrast (float): contrast_factor is drawn uniformly from
            [max(0, 1 - contrast), 1 + contrast].
        saturation (float): saturation_factor is drawn uniformly from
            [max(0, 1 - saturation), 1 + saturation].
        hue (float): hue_factor is drawn uniformly from [-hue, hue].
            Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Build a randomized transform from the given jitter ranges.

        Arguments are the same as those of ``__init__``.

        Returns:
            Transform that adjusts brightness, contrast, saturation and
            hue in a random order.
        """
        ops = []
        if brightness > 0:
            factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            # bind the factor as a default so each lambda keeps its own value
            ops.append(Lambda(lambda img, f=factor: adjust_brightness(img, f)))
        if contrast > 0:
            factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            ops.append(Lambda(lambda img, f=factor: adjust_contrast(img, f)))
        if saturation > 0:
            factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            ops.append(Lambda(lambda img, f=factor: adjust_saturation(img, f)))
        if hue > 0:
            factor = np.random.uniform(-hue, hue)
            ops.append(Lambda(lambda img, f=factor: adjust_hue(img, f)))
        np.random.shuffle(ops)
        return Compose(ops)

    def __call__(self, sample):
        """Apply a freshly randomized jitter to ``sample['image']``.

        Args:
            sample (dict): {'image': PIL Image, 'mask': PIL Image}.

        Returns:
            dict: same keys; image color-jittered, mask unchanged.
        """
        jitter = self.get_params(self.brightness, self.contrast,
                                 self.saturation, self.hue)
        return {'image': jitter(sample['image']), 'mask': sample['mask']}
# ...
class RandomResizedCrop(object):
    """Crop the image/mask sample to a random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original) is made, then resized to the
    given size. This is popularly used to train the Inception networks.

    Args:
        size (int): expected output size of each edge (output is square).
        interpolation: Default: PIL.Image.BILINEAR.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a
            random sized crop.
        """
        # The image never changes between attempts, so its dimensions and
        # area are loop invariants — compute them once, not per attempt.
        width, height = img.size
        area = width * height
        for attempt in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            # randomly swap so portrait and landscape crops are equally likely
            if random.random() < 0.5:
                w, h = h, w
            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w
        # Fallback: centered square crop with the largest edge that fits.
        w = min(width, height)
        i = (height - w) // 2
        j = (width - w) // 2
        return i, j, w, w

    def __call__(self, sample):
        """
        Args:
            sample (dict): {'image': PIL Image, 'mask': PIL Image}.

        Returns:
            The sample randomly cropped and resized to ``self.size``, as
            produced by ``resized_crop``.
        """
        i, j, h, w = self.get_params(sample['image'])
        return resized_crop(sample, i, j, h, w, self.size, self.interpolation)
# ...
class RandomHorizontalFlip(object):
    """Horizontally flip the given sample with probability 0.5."""

    def __call__(self, sample):
        """
        Args:
            sample (dict): {'image': PIL Image, 'mask': PIL Image}.

        Returns:
            dict: the sample, horizontally flipped half of the time.
        """
        return hflip(sample) if random.random() < 0.5 else sample
class RandomVerticalFlip(object):
    """Vertically flip the given sample with probability 0.5."""

    def __call__(self, sample):
        """
        Args:
            sample (dict): {'image': PIL Image, 'mask': PIL Image}.

        Returns:
            dict: the sample, vertically flipped half of the time.
        """
        return vflip(sample) if random.random() < 0.5 else sample
def hflip(sample):
    """Mirror both the image and the mask of a sample left-to-right.

    Args:
        sample (dict): {'image': PIL Image, 'mask': PIL Image}.

    Returns:
        dict: same keys, both entries horizontally flipped.

    Raises:
        TypeError: if the image is not a PIL Image (the mask is not checked).
    """
    image, mask = sample['image'], sample['mask']
    if not _is_pil_image(image):
        raise TypeError('img should be PIL Image. Got {}'.format(type(image)))
    return {'image': image.transpose(Image.FLIP_LEFT_RIGHT),
            'mask': mask.transpose(Image.FLIP_LEFT_RIGHT)}
def vflip(sample):
    """Mirror both the image and the mask of a sample top-to-bottom.

    Args:
        sample (dict): {'image': PIL Image, 'mask': PIL Image}.

    Returns:
        dict: same keys, both entries vertically flipped.

    Raises:
        TypeError: if the image is not a PIL Image (the mask is not checked).
    """
    image, mask = sample['image'], sample['mask']
    if not _is_pil_image(image):
        raise TypeError('img should be PIL Image. Got {}'.format(type(image)))
    return {'image': image.transpose(Image.FLIP_TOP_BOTTOM),
            'mask': mask.transpose(Image.FLIP_TOP_BOTTOM)}
# ...
class ToTensor(object):
    """Convert the PIL Image / ndarray entries of a sample to tensors.

    The image is converted from (H x W x C) in the range [0, 255] to a
    torch.FloatTensor of shape (C x H x W); see ``to_tensor`` for the
    exact per-entry behavior.
    """

    def __call__(self, pic):
        """
        Args:
            pic (dict): sample with 'image' and 'mask' entries.

        Returns:
            dict: the sample converted by ``to_tensor``.
        """
        return to_tensor(pic)
def _pil_to_tensor(pic):
    """Convert a single PIL Image to an unscaled CHW torch tensor.

    Integer modes 'I'/'I;16' and float mode 'F' go through numpy with the
    matching dtype; every other mode is read from the raw byte buffer as a
    ByteTensor. No value scaling is applied here.
    """
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    return img.transpose(0, 1).transpose(0, 2).contiguous()


def to_tensor(sample):
    """Convert the image and mask of a sample to torch tensors.

    The image is scaled to [0.0, 1.0] when it arrives as bytes; the mask
    is converted to float WITHOUT the /255 scaling so label values are
    preserved. In the ndarray and accimage branches the mask is passed
    through unconverted, matching the original behavior.

    Args:
        sample (dict): {'image': PIL Image or ndarray, 'mask': PIL Image}.

    Returns:
        dict: {'image': tensor (C x H x W), 'mask': converted mask}.

    Raises:
        TypeError: if the image is neither a PIL Image nor an ndarray.
    """
    pic, mask = sample['image'], sample['mask']
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
    if isinstance(pic, np.ndarray):
        # handle numpy array (HWC -> CHW)
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            img = img.float()
        # BUG FIX: the original returned the raw ndarray ('pic') here,
        # discarding the converted tensor ('img').
        return {'image': img, 'mask': mask}
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return {'image': torch.from_numpy(nppic), 'mask': mask}
    # handle PIL Image: identical conversion for image and mask, except
    # only the image bytes are scaled into [0.0, 1.0]
    img = _pil_to_tensor(pic)
    if isinstance(img, torch.ByteTensor):
        img = img.float() / 255.0
    mask_tensor = _pil_to_tensor(mask)
    if isinstance(mask_tensor, torch.ByteTensor):
        # masks keep their raw label values (no /255)
        mask_tensor = mask_tensor.float()
    return {'image': img, 'mask': mask_tensor}