So the format of a custom dataset should be like the following:
import torch
from torch.utils import data

class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, list_IDs, labels):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        X = torch.load('data/' + ID + '.pt')
        y = self.labels[ID]

        return X, y
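For reference, a minimal sketch of how such a dataset is typically plugged into a DataLoader (the IDs, labels and batch size below are made-up placeholders, not from my actual setup):

from torch.utils import data

# Hypothetical example inputs: a list of sample IDs and a dict mapping ID -> label
list_IDs = ['id-001', 'id-002', 'id-003']
labels = {'id-001': 0, 'id-002': 1, 'id-003': 0}

training_set = Dataset(list_IDs, labels)
training_loader = data.DataLoader(training_set, batch_size=2, shuffle=True)

for X_batch, y_batch in training_loader:
    # X_batch stacks the tensors loaded from data/<ID>.pt, y_batch batches the labels
    ...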
I'd like to have the ID information in the output in addition to X and y, so I did return X, y, ID, but now the DataLoader fails when I try to iterate over it.
All data returned by a dataset needs to be a tensor if you want to use the default collate_fn of the DataLoader. You have two options: write a custom collate function and pass it to the DataLoader, or wrap your ID inside a tensor (which is simpler, I guess) and unwrap it outside the DataLoader.
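For the first option, here is a minimal sketch of a custom collate function that collates the tensor parts as usual and passes the string IDs through as a plain list (the names and batch size are assumptions for illustration; 'dataset' stands for an instance of the Dataset above, modified to return X, y, ID):

from torch.utils import data
from torch.utils.data.dataloader import default_collate

def collate_with_ids(batch):
    # batch is a list of (X, y, ID) tuples coming from __getitem__
    Xs, ys, ids = zip(*batch)
    # Collate the tensor parts with the default collate, keep the string IDs as a list
    return default_collate(list(Xs)), default_collate(list(ys)), list(ids)

loader = data.DataLoader(dataset, batch_size=4, collate_fn=collate_with_ids)

for X_batch, y_batch, id_batch in loader:
    # id_batch is a plain Python list of the string IDs in this batch
    ...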
Ah sorry, I assumed your ID would be an integer. You cannot wrap a string in a tensor. I could think of some ways to achieve something like that, but they would not be very PyTorch-like. If you are interested in these ways, you can PM me.
That's what I thought about too. I also considered wrapping the loader itself, but one would have to define a new iterator for that. I proposed another method, and if it works (currently waiting for verification), I will post it here later on.
I think another way to do this without building a custom collate function would be not to return the ID but the idx directly within the __getitem__ implementation (the index is numerical and can be batched by the default collate function). Something like:
class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, list_IDs, labels, return_idx: bool = False):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs
        self.return_idx = return_idx

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        X = torch.load('data/' + ID + '.pt')
        y = self.labels[ID]

        if self.return_idx:
            return X, y, index
        return X, y
Then you can map the batch indices back to list_IDs outside the DataLoader, as sketched below.
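A rough sketch of that last step, assuming train_set is an instance of the class above with return_idx=True (the IDs, labels and batch size are placeholders):

from torch.utils import data

# Hypothetical example inputs
list_IDs = ['id-001', 'id-002', 'id-003', 'id-004']
labels = {'id-001': 0, 'id-002': 1, 'id-003': 0, 'id-004': 1}

train_set = Dataset(list_IDs, labels, return_idx=True)
loader = data.DataLoader(train_set, batch_size=2, shuffle=True)

for X_batch, y_batch, idx_batch in loader:
    # idx_batch is a tensor of dataset indices produced by the default collate_fn;
    # map it back to the original string IDs outside the DataLoader
    batch_ids = [train_set.list_IDs[i] for i in idx_batch.tolist()]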