Could I get feedback on a Custom Dataset Loader?

This forum has been very helpful to me, and I am very grateful to everyone who spends time helping and strengthening this community. One thing I've noticed is that most topics are about how to use things or how to fix things that don't work. I think it would also be incredibly beneficial (at least to me) to sometimes get feedback on how to improve code that already works but could be better. Anyway, I thought I'd ask, and if you're willing, I would be very thankful.

I've spent some time today trying to understand how to write custom datasets for text. The tutorial is a great start, but it doesn't cover the text-specific parts (using collate_fn for padding, etc.). I've managed to get something that processes and loads the data the way I want, but I'm quite sure there are ways to make the data loading faster, or things I've missed that could improve the quality of the code.

The Data:

I have a caption dataset with a CSV file containing an image_id and the corresponding caption for that image. I want to load the image and numericalize the caption so that it can be fed into a sequence model. For that I need to build a vocabulary, and I need a collate_fn so I can pad the captions to a common length.
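For reference, the dataset code reads two columns named image and caption. A hypothetical file in that shape (file names and captions are made up, and it assumes one row per image/caption pair, so an image with several captions appears on several rows) could be read like this:

```python
import io

import pandas as pd

# Hypothetical captions file in the layout CaptionDataset reads:
# one row per (image file name, caption) pair.
csv_text = """image,caption
dog_001.jpg,A dog runs across the grass.
dog_001.jpg,A brown dog is playing outside.
beach_042.jpg,Two people walk along the beach.
"""

df = pd.read_csv(io.StringIO(csv_text))
print(df.columns.tolist())  # ['image', 'caption']
print(len(df))              # 3 -> three (image, caption) samples
```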

The Code:

The CaptionDataset class looks like this:

import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CaptionDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=1):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption column
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize vocabulary class and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        # Images are loaded lazily from disk, one item at a time
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # <SOS> + caption tokens + <EOS>
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)
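If tokenizing each caption on every access turns out to be slow, one option is to numericalize all captions once up front and keep only image loading in __getitem__. A minimal sketch, where PretokenizedCaptions and simple_tokenizer are hypothetical names and the whitespace tokenizer is only a stand-in for tokenizer_eng:

```python
import torch
from torch.utils.data import Dataset


def simple_tokenizer(text):
    # Hypothetical stand-in for tokenizer_eng: lowercase + whitespace split.
    return text.lower().split()


class PretokenizedCaptions(Dataset):
    """Hypothetical variant: numericalize every caption once in __init__."""

    def __init__(self, captions, stoi, tokenizer=simple_tokenizer):
        sos, eos, unk = stoi["<SOS>"], stoi["<EOS>"], stoi["<UNK>"]
        # Unknown tokens fall back to <UNK>, as in Vocabulary.numericalize.
        self.encoded = [
            torch.tensor([sos] + [stoi.get(tok, unk) for tok in tokenizer(c)] + [eos])
            for c in captions
        ]

    def __len__(self):
        return len(self.encoded)

    def __getitem__(self, index):
        # Only a list lookup here; image loading would still happen lazily.
        return self.encoded[index]


stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3, "a": 4, "dog": 5}
ds = PretokenizedCaptions(["A dog", "A cat"], stoi)
print(ds[0].tolist())  # [1, 4, 5, 2]
print(ds[1].tolist())  # [1, 4, 3, 2]
```

The trade-off is memory: all encoded captions stay in RAM, which is usually small compared with the images themselves.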

In the DataLoader, I use the following collate function:

import torch
from torch.nn.utils.rnn import pad_sequence


class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        # If you want to use pack_padded_sequence, you should also get
        # the lengths of all the sequences here. I don't use it, so I
        # skip it.
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        # Shape (max_caption_len, batch_size), since batch_first=False
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
        return imgs, targets
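To show how the pieces connect, here is a sketch of passing MyCollate to a DataLoader, with a toy list of (image, caption) tensor pairs standing in for CaptionDataset. With the real dataset, pad_idx would come from dataset.vocab.stoi["<PAD>"]. torch.stack is used in place of the unsqueeze + torch.cat pair and gives the same result. Using a class rather than a lambda for collate_fn also keeps it picklable, which can matter when num_workers > 0 on platforms that start workers with spawn.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class MyCollate:
    """Same logic as the collate class above: stack images, pad captions."""

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = pad_sequence(
            [item[1] for item in batch], batch_first=False, padding_value=self.pad_idx
        )
        return imgs, targets


# Toy stand-in for CaptionDataset: (image tensor, numericalized caption) pairs.
toy_dataset = [
    (torch.rand(3, 32, 32), torch.tensor([1, 5, 6, 2])),
    (torch.rand(3, 32, 32), torch.tensor([1, 7, 2])),
    (torch.rand(3, 32, 32), torch.tensor([1, 8, 9, 10, 2])),
]

pad_idx = 0  # with the real dataset: dataset.vocab.stoi["<PAD>"]
loader = DataLoader(
    toy_dataset,
    batch_size=3,
    shuffle=False,
    num_workers=0,  # raise this to open/transform images in parallel
    collate_fn=MyCollate(pad_idx=pad_idx),
)

imgs, targets = next(iter(loader))
print(imgs.shape)     # torch.Size([3, 3, 32, 32])
print(targets.shape)  # torch.Size([5, 3]) -> (max_len, batch)
```

num_workers and pin_memory are both standard DataLoader arguments and are usually the first things worth tuning for loading speed.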

I think the Vocabulary class below is skippable, but it could be of interest:

class Vocabulary:
    def __init__(self, freq_threshold):
        # Integer to String, and String to Integer
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            # tokenizer_eng is defined elsewhere (not shown in this post)
            for word in tokenizer_eng(sentence):
                # Simple way of building frequencies
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                # If it ever reaches a frequency of the threshold
                # we add it to our vocabulary (stoi & itos)
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = tokenizer_eng(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]
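For comparison, the frequency counting in build_vocabulary can also be written with collections.Counter. This hypothetical helper (build_stoi is not part of the original code, and it takes the tokenizer as a parameter because tokenizer_eng is not shown) keeps every word whose total count is at least freq_threshold. That selects the same set of words as the intended logic of the loop above, though indices may be assigned in a different order when freq_threshold > 1.

```python
from collections import Counter


def build_stoi(sentences, tokenizer, freq_threshold=1):
    """Hypothetical helper: count all tokens, then keep the frequent ones."""
    stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    counts = Counter(tok for sentence in sentences for tok in tokenizer(sentence))
    # Counter preserves first-occurrence order, so indices follow that order.
    for word, freq in counts.items():
        if freq >= freq_threshold:
            stoi[word] = len(stoi)
    return stoi


stoi = build_stoi(["a dog runs", "a cat runs", "a dog"], str.split, freq_threshold=2)
itos = {i: w for w, i in stoi.items()}
print(stoi)
# {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3, 'a': 4, 'dog': 5, 'runs': 6}
```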

The Question :slight_smile:

Does this structure look good? Are there things you would have done differently, for efficiency or anything else?

You can speed up the process by creating a data_prefetcher (follow the example here).
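The example linked in the reply is not reproduced here. As a rough sketch of the general idea, assuming the goal is to overlap host-to-GPU copies with computation, a prefetcher can copy the next batch on a separate CUDA stream while the current batch is being processed. DataPrefetcher is a hypothetical name, and without a GPU it simply iterates over the loader. The non_blocking copies can only overlap with compute when the DataLoader was created with pin_memory=True.

```python
import torch


class DataPrefetcher:
    """Hypothetical sketch: copy the next batch to the GPU on a side stream
    while the current batch is in use. Without CUDA it just iterates."""

    def __init__(self, loader, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)  # device given as a string
        self.use_cuda = self.device.type == "cuda"
        self.stream = torch.cuda.Stream() if self.use_cuda else None
        self.loader = iter(loader)
        self._preload()

    def _preload(self):
        try:
            imgs, targets = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        if self.use_cuda:
            # Issue the copies on the side stream so they can overlap with
            # work already queued on the default stream.
            with torch.cuda.stream(self.stream):
                imgs = imgs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
        self.next_batch = (imgs, targets)

    def __iter__(self):
        return self

    def __next__(self):
        if self.next_batch is None:
            raise StopIteration
        if self.use_cuda:
            # Make the compute stream wait for the copy, and tell the caching
            # allocator these tensors are now used on the compute stream.
            torch.cuda.current_stream().wait_stream(self.stream)
            for tensor in self.next_batch:
                tensor.record_stream(torch.cuda.current_stream())
        batch = self.next_batch
        self._preload()  # start copying the following batch right away
        return batch


# Toy batches standing in for a DataLoader over CaptionDataset.
batches = [(torch.zeros(2, 3, 8, 8), torch.zeros(5, 2, dtype=torch.long)) for _ in range(3)]

num_batches = 0
for imgs, targets in DataPrefetcher(batches):
    num_batches += 1  # forward/backward pass would go here
print(num_batches)  # 3
```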


The code looks good and seems well structured, so I think I understand what is going on even without running it.
From my personal point of view, clean and modifiable code is really important if you are planning to run experiments. While performance is certainly important as well, I think functionality comes first, and you can try to optimize certain parts of the code later (especially after isolating the bottlenecks). :wink:
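Isolating the bottlenecks can start with something as simple as timing how long the training loop waits for each batch versus how long the step itself takes. A minimal sketch, where profile_loader is a hypothetical helper and train_step stands in for the forward/backward pass:

```python
import time


def profile_loader(loader, train_step, max_batches=50):
    """Hypothetical helper: split wall time into data waiting vs. compute."""
    data_time, step_time = 0.0, 0.0
    it = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent waiting for the loader
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)  # on a GPU, end this with torch.cuda.synchronize()
        t2 = time.perf_counter()
        data_time += t1 - t0
        step_time += t2 - t1
    return data_time, step_time


data_t, step_t = profile_loader([1, 2, 3], lambda batch: time.sleep(0.01))
print(f"data: {data_t:.3f}s, step: {step_t:.3f}s")
```

If data time dominates, the loader (image decoding, transforms, tokenization) is the bottleneck; if step time dominates, faster loading will not help much. Because CUDA kernels run asynchronously, train_step should synchronize before returning when timing on a GPU, and torch.profiler can give a more detailed breakdown.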
