Issues with torch.utils.data.random_split()

Hi,
I’m not getting a split when I use torch.utils.data.random_split. Here’s my code:

import os

import torch
from torch.utils.data import Dataset

class DeviceLoader(Dataset):

    def __init__(self, root_dir, train=True, transform=None):
        self.file_path = root_dir
        self.train = train
        self.transform = transform
        # Collect every file under root_dir, including subdirectories.
        self.file_names = ['%s/%s' % (root, file) for root, _, files in os.walk(root_dir) for file in files]
        self.len = len(self.file_names)
        # Map each class folder name to an integer label.
        self.labels = {'BP_Raw_Images': 0, 'DT_Raw_Images': 1, 'GL_Raw_Images': 2, 'PO_Raw_Images': 3, 'WS_Raw_Images': 4}

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        # The class folder is the sixth path component for this root_dir.
        device = file_name.split('/')[5]
        # self.pil_loader is not shown in this post.
        img = self.pil_loader(file_name)
        if self.transform:
            img = self.transform(img)
        cat = self.labels[device]
        if self.train:
            return img, cat
        else:
            return img, file_name

full_data = DeviceLoader(root_dir='/kaggle/input/devices/dataset/', transform=transforms, train=True)
train_size = int(0.7*len(full_data))
val_size = len(full_data) - train_size
train_data, val_data = torch.utils.data.random_split(full_data,[train_size,val_size])

I get the correct numbers for train_size and val_size, but after random_split both train_data and val_data seem to contain the whole of full_data. No split appears to happen.

Please help me fix this issue.
@ptrblck

Hi,

That looks odd. I just ran the same splitting code on a random TensorDataset and it works: the lengths of the dataset, the training set, and the validation set all come out correctly, and you are getting the correct sizes too.

import torch
from torch.utils.data import TensorDataset, random_split

# 100 random samples standing in for the real images.
init_dataset = TensorDataset(
    torch.randn(100, 3, 5, 5),
)

train_size = int(0.7 * len(init_dataset))
val_size = len(init_dataset) - train_size
train_data, val_data = random_split(init_dataset, [train_size, val_size])

What is the length of your dataset? I would like to do some experiments with your defined class.

Total length is 3001 @Nikronic

I copied your code and ran it, and it works fine. Could you share the versions of the packages in your environment?

Here is a script you can use: https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py

Run it with python collect_env.py and paste the output here.

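I’m sorry, this was my mistake. I was checking train_data.dataset.len instead of len(train_data). train_data.dataset is the original full dataset, so its len is always the full 3001. The split is happening correctly. No issues.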
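
For anyone landing here with the same confusion, here is a minimal sketch of the difference, using a random TensorDataset as a stand-in for the real image dataset (the 100-sample size and tensor shape are assumptions for illustration only). random_split returns Subset objects: len() on a Subset counts only that split's indices, while its .dataset attribute still points at the original, unsplit dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset (assumption: 100 random samples instead of real images).
full_data = TensorDataset(torch.randn(100, 3, 5, 5))

train_size = int(0.7 * len(full_data))
val_size = len(full_data) - train_size
train_data, val_data = random_split(full_data, [train_size, val_size])

# len() on a Subset counts only the indices assigned to that split.
print(len(train_data), len(val_data))

# .dataset is the original dataset, so its length is the full length.
print(train_data.dataset is full_data)
print(len(train_data.dataset))

# The two splits share no indices.
train_idx = set(int(i) for i in train_data.indices)
val_idx = set(int(i) for i in val_data.indices)
overlap = train_idx & val_idx
print(len(overlap))
```

So to check whether a split happened, compare len(train_data) and len(val_data), not anything reached through .dataset.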