Hey! Sorry to revive such an old topic, but recently I’ve been redoing my implementation of a project to try to squeeze the maximum amount of performance out of it, and I decided to use an HDF5 (.h5) file for my datasets.
In your recommendations you mention to use a DataLoader
with a batch_sampler
because random access to hdf5 files is slow. What would be the difference between using a custom batch_sampler
to grab random samples, and just using the default sampler
with shuffle=True
?
For context, this is my H5Dataset implementation:
class H5Dataset(torch.utils.data.Dataset):
    """Map-style dataset over an HDF5 file with lazy, per-worker file handles.

    The ``h5py.File`` handle is deliberately NOT opened in ``__init__``: it is
    created on the first ``__getitem__`` call, so that each DataLoader worker
    process opens its own handle (h5py file handles must not be shared across
    forked worker processes).

    Parameters
    ----------
    path : str
        Path to the ``.h5`` file.
    split : str
        Top-level group to read (e.g. ``"train"``/``"validation"``/``"test"``),
        or the special value ``"pixel_values"``.
    mode : str
        ``"baseline"`` returns a single pre-selected image view (index 4) per
        sample; ``"patching"`` returns the full ``pixel_values`` entry; any
        other value omits ``"pixel_values"`` from the output dict.
    """

    def __init__(self, path, split, mode):
        self.file_path = path
        self.dataset = None  # lazily-opened h5py.File, one per worker
        self.split = split
        self.mode = mode
        with h5py.File(self.file_path, 'r') as file:
            if self.split == "pixel_values":
                self.dataset_len = len(file[self.split])
            else:
                group = file[self.split]
                # Bug fix: the original assertion compared len(category) with
                # itself and never validated "label" (which __getitem__ reads).
                # Check every per-sample field exactly once, and raise instead
                # of assert (assert is stripped under `python -O`).
                fields = ("img_id", "category", "label",
                          "attention_mask", "input_ids")
                lengths = {name: len(group[name]) for name in fields}
                if len(set(lengths.values())) != 1:
                    raise ValueError(
                        f"non matching number of entries in .h5 file: {lengths}")
                self.dataset_len = lengths["img_id"]
                self.categories = [category.decode("utf-8")
                                   for category in np.unique(group["category"])]

    def __getitem__(self, idx):
        """Return one sample as a dict, opening the HDF5 file on first use."""
        if self.dataset is None:
            # Opened here (not in __init__) so every worker gets its own handle.
            self.dataset = h5py.File(self.file_path, 'r')
        output = {}
        output["attention_mask"] = self.dataset[self.split + "/attention_mask"][idx]
        output["category"] = self.dataset[self.split + "/category"][idx].decode("utf-8")
        output["img_id"] = self.dataset[self.split + "/img_id"][idx]
        output["input_ids"] = self.dataset[self.split + "/input_ids"][idx]
        output["label"] = self.dataset[self.split + "/label"][idx]
        if self.mode == "baseline":
            # Single pre-selected view (index 4) of the referenced image.
            output["pixel_values"] = self.dataset["pixel_values"][output["img_id"]][4]
        elif self.mode == "patching":
            output["pixel_values"] = self.dataset["pixel_values"][output["img_id"]]
        return output

    def __len__(self):
        return self.dataset_len
and in my Trainer implementation I set up the dataset-related objects (Datasets and DataLoaders) as follows:
# Build one Dataset/DataLoader pair per split from the same HDF5 file.
# Only the training loader shuffles; evaluation order stays fixed.
h5_path = os.path.join("datasets", "dataset_folder", "dataset_file.h5")
self.train_dataset = H5Dataset(h5_path, "train", model)
self.train_loader = torch.utils.data.DataLoader(
    self.train_dataset, batch_size=self.batch_size,
    shuffle=True, num_workers=4)
self.validation_dataset = H5Dataset(h5_path, "validation", model)
self.validation_loader = torch.utils.data.DataLoader(
    self.validation_dataset, batch_size=self.batch_size,
    shuffle=False, num_workers=4)
self.test_dataset = H5Dataset(h5_path, "test", model)
self.test_loader = torch.utils.data.DataLoader(
    self.test_dataset, batch_size=self.batch_size,
    shuffle=False, num_workers=4)
Any tips would be greatly appreciated as I am quite new to this