Train_loc = ‘./Mehsana/Train_data/’
Val_loc = ‘./Mehsana/Val_data/’
X = glob.glob(Train_loc+‘Images/’+‘img’)
Y = glob.glob(Train_loc +‘labels/’+’.img’)
X_train, Y_train, X_test, Y_test = X[:2000], Y[:2000], X[2000:], Y[2000:]
del X,Y
class CustomDataset(Dataset):
def init(self, image_paths, target_paths, train=True): # initial logic happens like transform
self.image_paths = image_paths
self.target_paths = target_paths
self.transforms = transforms.Compose([
#transforms.RandomHorizontalFlip(p=0.5),
#transforms.RandomVerticalFlip(p=0.5),
transforms.Resize((224,224)), transforms.ToTensor(),
transforms.Normalize(
mean= [688.8383261571815, 890.7605253921796, 1012.9213980828494, 3228.524152857917],
std=[238.88659552814724, 267.57694692565684, 347.344288531034, 689.5044531578629])])
def __getitem__(self, index):
img =gdal.Open(self.image_paths[index],gdal.GA_ReadOnly)
img= img.ReadAsArray()
print(img.shape)
img = np.moveaxis(img,0,-1)
mask = gdal.Open(self.target_paths[index],gdal.GA_ReadOnly)
print(np.unique(mask))
t_image = self.transforms(img)
return t_image, mask
def __len__(self): # return count of sample we have
return len(self.image_paths)
train_dataset = CustomDataset( X_train, Y_train, train=True)
train_dataset = CustomDataset(train_image_paths, train_mask_paths, train=True)
I get this error when I am trying to generate a custom dataset in pytorch for Unet model. My images are of size (250,250,4).