This forum has been very helpful to me, and I'm grateful to everyone who spends time helping and strengthening this community. One thing I've noticed is that most topics are about how to use things or fix things that don't work, but I think it would also be incredibly beneficial (at least to me) to occasionally get feedback on code that works but could be improved. So I thought I'd ask, and if you're willing, I'd be very thankful.
I've spent some time today trying to understand how to write custom datasets for text. The tutorial is a great start, but it doesn't cover working with text (using collate_fn for padding, etc.). I've managed to get something that processes and loads the data the way I want, but I'm quite sure there are ways to make the data loading faster, or simply things I've missed that would improve the quality of the code.
The Data:
I have a caption dataset with a CSV file containing an image id (the image file name) and the corresponding caption. I want to load the image and numericalize the caption so it can be fed into a sequence model, which means I need a collate_fn to pad the caption sentences, and I also need to build a vocabulary.
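To give an idea of the format, here are a couple of made-up rows (my real file names and captions differ):

image,caption
1001.jpg,a dog runs across a field
1002.jpg,two children play on the beach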
The Code:
The CaptionDataset class looks like this:
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CaptionDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=1):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get the img and caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize the vocabulary and build it from all captions
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # Numericalize the caption and wrap it in <SOS>/<EOS> tokens
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)
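For reference, this is roughly how I construct the dataset (the paths and the transform are just placeholders):

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = CaptionDataset(
    root_dir="images/",            # placeholder path
    captions_file="captions.csv",  # placeholder path
    transform=transform,
)

# A single sample: an image tensor plus a 1-D tensor of token indices
# that starts with <SOS> and ends with <EOS>
img, caption = dataset[0]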
In the DataLoader I use the following collate function:
from torch.nn.utils.rnn import pad_sequence


class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        # If you want to use pack_padded_sequence you should also
        # collect the sequence lengths here; I don't use it, so I skip that.
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)

        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return imgs, targets
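The DataLoader itself is then set up roughly like this (the batch size is a placeholder), using the <PAD> index from the dataset's vocabulary:

from torch.utils.data import DataLoader

pad_idx = dataset.vocab.stoi["<PAD>"]
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=MyCollate(pad_idx=pad_idx),
)

for imgs, targets in loader:
    # imgs: (batch_size, 3, 224, 224) with the transform above
    # targets: (max_caption_len, batch_size) since batch_first=False
    break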
The Vocabulary class below is probably skippable, but it could be of interest:
class Vocabulary:
    def __init__(self, freq_threshold):
        # Integer to string, and string to integer
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in tokenizer_eng(sentence):
                # Simple way of counting word frequencies
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                # The moment a word reaches the frequency threshold
                # we add it to our vocabulary (stoi & itos)
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = tokenizer_eng(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]
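tokenizer_eng is defined elsewhere in my script; to make the sketch below self-contained, assume it is just a lower-casing whitespace split:

def tokenizer_eng(text):
    # Stand-in tokenizer for this example only
    return text.lower().split()

vocab = Vocabulary(freq_threshold=2)
vocab.build_vocabulary([
    "a dog runs",
    "a dog sleeps",
    "a cat sleeps",
])

# "a", "dog" and "sleeps" reach the threshold of 2; "runs" and "cat" do not
print(vocab.stoi)                         # {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3, 'a': 4, 'dog': 5, 'sleeps': 6}
print(vocab.numericalize("a dog jumps"))  # [4, 5, 3] -- "jumps" maps to <UNK>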
The Question:
Does this structure look good? Are there things you would have done differently, for efficiency or otherwise?