Hi,
I’m not getting a split when I use `torch.utils.data.random_split`. Here’s my code:
class DeviceLoader(Dataset):
    """Image dataset whose labels come from the top-level sub-directory name.

    Every file found (recursively, via ``os.walk``) under ``root_dir`` is one
    sample. In train mode a sample's label is looked up in ``self.labels``
    using the name of the first directory below ``root_dir`` on that file's
    path, e.g. ``<root_dir>/GL_Raw_Images/x.jpg`` -> ``2``.

    Args:
        root_dir: Directory to scan for sample files.
        train: If True, ``__getitem__`` returns ``(img, label)``; otherwise it
            returns ``(img, file_name)`` and no label lookup is done.
        transform: Optional callable applied to each loaded image.

    NOTE(review): ``__getitem__`` calls ``self.pil_loader``, which this class
    does not define; it must be provided elsewhere (e.g. by a subclass or an
    instance attribute) — confirm.
    """

    def __init__(self, root_dir, train=True, transform=None):
        self.file_path = root_dir
        self.train = train
        self.transform = transform
        # os.path.join avoids the doubled '/' that '%s/%s' produced when
        # root_dir ends in a separator. Sorting makes sample indices (and so
        # any seeded random_split) reproducible, because os.walk yields
        # entries in arbitrary order.
        self.file_names = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(root_dir)
            for name in files
        )
        self.len = len(self.file_names)
        self.labels = {
            'BP_Raw_Images': 0,
            'DT_Raw_Images': 1,
            'GL_Raw_Images': 2,
            'PO_Raw_Images': 3,
            'WS_Raw_Images': 4,
        }

    def __len__(self):
        return len(self.file_names)

    def _device_of(self, file_name):
        """Return the first directory component of *file_name* below root_dir."""
        # Replaces the hard-coded file_name.split('/')[5], which only worked
        # when root_dir was exactly '/kaggle/input/devices/dataset/'-deep.
        relative = os.path.relpath(file_name, self.file_path)
        return relative.split(os.sep)[0]

    def __getitem__(self, idx):
        """Load sample *idx*.

        Returns:
            ``(img, label)`` in train mode, else ``(img, file_name)``.

        Raises:
            KeyError: in train mode, if the file's top-level folder is not a
                key of ``self.labels``.
        """
        file_name = self.file_names[idx]
        img = self.pil_loader(file_name)
        if self.transform:
            img = self.transform(img)
        if self.train:
            return img, self.labels[self._device_of(file_name)]
        # Inference mode: skip the label lookup so unlabeled folders work.
        return img, file_name
# Build the full dataset and split it 70/30 into train/validation subsets.
# NOTE(review): `transforms` must be a callable transform (e.g. a Compose
# instance); if it is the torchvision.transforms module itself, __getitem__
# will fail when it calls it — confirm.
full_data = DeviceLoader(root_dir='/kaggle/input/devices/dataset/', transform=transforms, train=True)
train_size = int(0.7 * len(full_data))
val_size = len(full_data) - train_size
# random_split returns torch.utils.data.Subset objects. Each Subset keeps a
# reference to the whole dataset in `.dataset` (so `train_data.dataset` is
# full_data), but only serves the samples listed in its `.indices`; check
# len(train_data) / len(val_data) to see the actual split.
train_data, val_data = torch.utils.data.random_split(full_data, [train_size, val_size])
I get the correct values for `train_size` and `val_size`, but after calling `random_split`, both `train_data` and `val_data` contain all of `full_data`. No split is happening.
Please help me fix this issue.
@ptrblck