While building a custom data-loader, I need different kinds of elastic transforms for images and mask(solving segmentation problem), I defined both elastic transform in the class which is giving me Getting Fatal Python Error: Cannot recover from stack overflow
Following is my data-loader-
import torch.utils.data as data
class DataLoaderSegmentation(data.Dataset):
def __init__(self,folder_path,transform = None):
super(DataLoaderSegmentation, self).__init__()
self.img_files = glob.glob(os.path.join(folder_path,'images','*.tif'))
self.mask_files = glob.glob(os.path.join(folder_path,'mask','*.bmp'))
self.transforms = transform
#for img_path in img_files:
# self.mask_files.append(os.path.join(folder_path,'masks',os.path.basename(img_path))
def mask_to_class(self,mask):
target = torch.from_numpy(mask)
h,w = target.shape[0],target.shape[1]
masks = torch.empty(h, w, dtype=torch.long)
colors = torch.unique(target.view(-1,target.size(2)),dim=0).numpy()
target = target.permute(2, 0, 1).contiguous()
mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
for k in mapping:
idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
validx = (idx.sum(0) == 3)
masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
return masks
def elastic_transform_nearest(self,image, alpha=1000, sigma=20, spline_order=0, mode='nearest', random_state=np.random):
image = np.array(image)
assert image.ndim == 3
shape = image.shape[:2]
dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
sigma, mode="constant", cval=0) * alpha
dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
sigma, mode="constant", cval=0) * alpha
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
result = np.empty_like(image)
for i in range(image.shape[2]):
result[:, :, i] = map_coordinates(
image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
result = Image.fromarray(result)
return result
def elastic_transform_bilinear(image, alpha=1000, sigma=20, spline_order=1, mode='nearest', random_state=np.random):
image = np.array(image)
assert image.ndim == 3
shape = image.shape[:2]
dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
sigma, mode="constant", cval=0) * alpha
dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
sigma, mode="constant", cval=0) * alpha
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
result = np.empty_like(image)
for i in range(image.shape[2]):
result[:, :, i] = map_coordinates(
image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
result = Image.fromarray(result)
return result
def transform(self,image,mask):
i, j, h, w = transforms.RandomCrop.get_params(
image, output_size=(512, 512))
image = TF.crop(image, i, j, h, w)
mask = TF.crop(mask, i, j, h, w)
# Random horizontal flipping
if random.random() > 0.5:
image = TF.hflip(image)
mask = TF.hflip(mask)
image = TF.rotate(image,90)
mask = TF.rotate(mask,90)
image = TF.rotate(image,180)
mask = TF.rotate(mask,180)
image = TF.rotate(image,270)
mask = TF.rotate(mask,270)
return image, mask
def __getitem__(self, index):
img_path = self.img_files[index]
mask_path = self.mask_files[index]
data = Image.open(img_path)
data = self.elastic_transform_bilinear(data)
label = self.elastic_transform_nearest(label)
data,label = self.transform(data,label)
label = np.array(label)
data = np.array(data)
mask = self.mask_to_class(label)
if transforms is not None:
data = self.transforms(data)
return data,mask
def __len__(self):
return len(self.img_files)
If I run the code without elastic transforms
, it is running fine, also when I am running them using lambda transform and using them with transform, then it is working fine but as I don’t know how to use different function for images and mask, I am resorting to my above approach-
data_transforms = transforms.Compose([
transforms.Lambda(gaussian_blur),
transforms.Lambda(elastic_transform),
transforms.ToTensor(),
transforms.Normalize(mean=train_mean, std=train_std)
])