So I have a very strange issue.
I have a Dataset with labels between 0 and 100 (101 classes). I split it internally: the first 91 classes form the training set and the final 10 the validation set.
When I pass this dataset to a DataLoader (with or without a sampler), it returns labels that are outside the label set, for example 112, 105, etc. I am very confused as to how this is happening, as I tried to simplify things as much as possible and it still happens.
When I inspect my dataset object directly with .__getitem__(index), it returns the correct labels. So whatever is happening must be happening inside the DataLoader.
This is my Dataset class, which pulls from HDF5 and returns the image X and label y.
class FoodData(torch.utils.data.Dataset):
    def __init__(self, h5_path, train=True, transform=None):
        """
        Inputs:
            h5_path (str): path of the HDF5 file to load
            train (bool): whether to build the training dataset (first 91
                classes, labels 0-90) or the validation dataset
            transform (torch transforms): skipped if None, otherwise
                applied to each image
        """
        self.h5_path = h5_path
        self.train = train
        self.train_idx = None
        self.train_cutoff = 0
        self.labels = None  # this isn't too memory-intensive
        self.transform = transform
        # open the HDF5 file; the handle is kept for the dataset's lifetime
        self.data = h5py.File(self.h5_path, "r")
        self.labels = np.array(self.data["labels"])
        # find the train/validation split point: first 91 classes for train
        # (assumes the labels are stored in sorted order)
        self.train_idx = np.argmax(self.labels > 90)
        if self.train:
            self.labels = self.labels[:self.train_idx]
        else:
            self.train_cutoff = self.train_idx
            self.labels = self.labels[self.train_idx:]
        self.length = self.labels.shape[0]

    def __getitem__(self, index):
        """
        Pull an image and label from the initialized HDF5 file.
        """
        X = self.data["images"][index + self.train_cutoff]
        y = self.data["labels"][index + self.train_cutoff]
        if self.transform is not None:
            X = Image.fromarray(X)
            X = self.transform(X)
        # y = self.labels[index]
        return X, y

    def __len__(self):
        return self.length
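For reference, the split logic above relies on np.argmax over a boolean array returning the index of the first True, i.e. the first sample whose label exceeds 90; this only works if the labels are stored in sorted order. A minimal sketch with made-up labels:

```python
import numpy as np

# Hypothetical sorted labels: three samples per class for classes 0..100.
labels = np.repeat(np.arange(101), 3)

# np.argmax on a boolean array returns the index of the first True,
# i.e. the first sample belonging to class 91.
train_idx = np.argmax(labels > 90)
print(train_idx)                 # 273 (91 classes * 3 samples each)
print(labels[:train_idx].max())  # 90
print(labels[train_idx])         # 91
```

One caveat: if no label exceeds 90, np.argmax returns 0 and the training split is silently empty.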
I instantiate it as such:
train_dataset = FoodData("data.h5", train=True)
If I then do:
for i in tqdm(range(100999)):
    if train_dataset.__getitem__(i)[1] > 100:
        print("That's not right")
        break
nothing happens, and nothing looks wrong when I inspect the labels manually.
However, if I do:
valloader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    num_workers=3)

for data in valloader:
    im, lab = data
    if np.sum(lab.numpy() > 100) > 0:
        print(lab)
        break
then it outputs something like:
tensor([ 79, 48, 63, 83, 42, 27, 38, 76, 26, 15, 74, 11, 35, 60,
82, 39, 48, 10, 46, 32, 89, 0, 4, 18, 54, 55, 87, 0,
20, 16, 38, 54, 52, 58, 43, 71, 23, 49, 90, 53, 69, 32,
16, 0, 83, 0, 0, 72, 0, 53, 9, 80, 66, 5, 35, 59,
13, 13, 32, 65, 84, 45, 54, 61, 83, 0, 55, 62, 65, 40,
91, 92, 38, 42, 78, 87, 85, 5, 25, 8, 55, 3, 48, 26,
33, 23, 81, 31, 0, 40, 74, 13, 84, 89, 56, 15, 85, 59,
47, 2, 60, 67, 122, 59, 35, 83, 53, 46, 17, 73, 9, 71,
2, 55, 58, 65, 78, 7, 57, 90, 31, 1, 14, 16, 68, 49,
4, 28], dtype=torch.uint8)
where obviously we have the label 122, which shouldn't exist. Please let me know if you have run across anything similar.
Interestingly, if I turn off shuffle, I get outputs such as this:
tensor([ 0, 0, 0, 0, 0, 0, 0, 66, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 66, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 66, 0, 0, 0, 0, 0,
0, 66, 66, 0, 0, 0, 0, 0, 0, 0, 0, 125, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 66, 0, 0, 66, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 66, 1, 66, 66, 0, 0, 0, 0, 0,
0, 0], dtype=torch.uint8)
indicating that something strange is turning my zeros into 66, 125, etc.
Further edit: I am going insane. If num_workers=0, this issue does not occur! It must be related to the fact that I am preloading the HDF5 file or something…
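That hunch fits how DataLoader workers are created: with the default fork start method, each worker inherits the dataset object, including the h5py.File handle opened in __init__, and HDF5 handles are not safe to share across processes. A stdlib-only sketch of the underlying failure mode, using a plain file handle in place of the HDF5 one (POSIX only, since it uses os.fork):

```python
import os
import tempfile

# Write 100 known bytes so any corruption is obvious.
fd, path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
    f.write(bytes(range(100)))

# Open the handle BEFORE forking, like h5py.File(...) in __init__.
# buffering=0 so reads hit the shared open file description directly.
fh = open(path, "rb", buffering=0)

pid = os.fork()
if pid == 0:
    fh.read(50)    # the "worker" advances the shared file offset...
    os._exit(0)
os.waitpid(pid, 0)

# ...and the parent now reads from offset 50, not 0.
print(fh.read(1)[0])  # 50, even though the parent never read anything
fh.close()
os.remove(path)
```

h5py fails in a messier way than a simple offset shift (its internal caches and state get corrupted, hence the garbage labels), but the root cause is the same: one process-level resource shared by several processes.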
Edit: here is the final code that worked for me:
class FoodData(torch.utils.data.Dataset):
    def __init__(self, h5_path, train_idx=90999, transform=None):
        """
        Inputs:
            h5_path (str): path of the HDF5 file to load
            train_idx (int): index at which the train set stops; everything
                after it is the holdout set. For Food-101 that's 90999.
            transform (torch transforms): skipped if None, otherwise
                applied to each image
        """
        self.h5_path = h5_path
        self.train_idx = train_idx  # manually found this
        self.transform = transform

    def __getitem__(self, index):
        """
        Pull an image and label from the HDF5 file, opening it fresh on
        each call so no handle is shared across worker processes.
        """
        with h5py.File(self.h5_path, "r") as f:
            X = f["images"][index]
            y = f["labels"][index]
        if self.transform is not None:
            X = Image.fromarray(X)
            X = self.transform(X)
        return X, y

    def __len__(self):
        with h5py.File(self.h5_path, "r") as f:
            return f["labels"].shape[0]
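Reopening the file on every __getitem__ works, but it pays the open/close cost per sample. A common alternative (a sketch, not part of the original post) is to open the file lazily on first access in each process, so every DataLoader worker gets its own handle. Illustrated here with a plain binary file so it stays self-contained; for h5py you would store h5py.File(self.path, "r") in self._fh instead:

```python
import os

class LazyFileDataset:
    """Hypothetical sketch: each process opens its own handle on first use."""

    def __init__(self, path):
        self.path = path
        self._fh = None    # nothing is opened in __init__, so nothing is
        self._pid = None   # shared when DataLoader forks its workers

    def _file(self):
        # (Re)open if this is the first access, or if we were forked into
        # a new process after the handle was created.
        if self._fh is None or self._pid != os.getpid():
            self._fh = open(self.path, "rb", buffering=0)
            self._pid = os.getpid()
        return self._fh

    def __getitem__(self, index):
        fh = self._file()
        fh.seek(index)        # one byte per "sample" in this toy layout
        return fh.read(1)[0]

    def __len__(self):
        return os.path.getsize(self.path)
```

In the real dataset this class would subclass torch.utils.data.Dataset; the worker_init_fn argument of DataLoader is another place to do the per-worker open.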