I have a question about the PyTorch Dataset class.
What should I do when I need to skip the n-th element during training?
def __getitem__(self, idx):
    # When idx == 100, the data is ill-formed text,
    # so I want to skip this sample during training.
    ...
These cases happen when I filter some text and save it in MongoDB.
When the filtered text is empty or too short, I want to skip that sample.
You could return a constant tensor, which you could then filter out in the training loop, but note that this approach would lower the batch size and in the edge case you could also end up with a completely empty batch.
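This first approach could be sketched as follows: the dataset returns a sentinel (here `None` instead of a constant tensor, to keep the sketch dependency-free) for bad samples, and a custom collate function drops the sentinels. The class names and the `min_len` threshold are illustrative; in practice the dataset would subclass `torch.utils.data.Dataset` and `skip_invalid_collate` would be passed as `collate_fn` to the `DataLoader`.

```python
def skip_invalid_collate(batch):
    # Drop sentinel (None) entries before the batch is assembled.
    return [sample for sample in batch if sample is not None]

class TextDataset:
    def __init__(self, texts, min_len=3):
        self.texts = texts
        self.min_len = min_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Ill-formed sample: return a sentinel instead of real data.
        if len(text) < self.min_len:
            return None
        return text

ds = TextDataset(["hello", "", "world", "ab"])
batch = skip_invalid_collate([ds[i] for i in range(len(ds))])
# Note the effective batch size shrank from 4 to 2 samples.
```

This illustrates the drawback mentioned above: the resulting batch can be smaller than requested, or even empty.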
The better approach would be to remove these invalid indices up front.
If you can compute the invalid indices in advance (e.g. in the __init__ method), you could build a valid_idx list and return only the samples that should be returned:

self.valid_idx = [0, 2, 3, 5, 6, 8, ...]  # keep only valid indices, or filter out invalid ones
def __getitem__(self, index):
    idx = self.valid_idx[index]
    data = self.data[idx]
    return data
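Putting the pieces together, a minimal sketch of this approach could look like the following. The class name, the string data, and the `min_len` validity check are illustrative stand-ins; in practice the class would subclass `torch.utils.data.Dataset` and the data would come from your storage.

```python
class FilteredTextDataset:
    def __init__(self, texts, min_len=3):
        self.data = texts
        # Pre-compute the indices of well-formed samples once, up front.
        self.valid_idx = [i for i, t in enumerate(texts) if len(t) >= min_len]

    def __len__(self):
        # The dataset length is the number of *valid* samples.
        return len(self.valid_idx)

    def __getitem__(self, index):
        # Map the external index to the underlying valid index.
        idx = self.valid_idx[index]
        return self.data[idx]

ds = FilteredTextDataset(["hello", "", "world", "ab"])
# len(ds) is 2; ds[0] is "hello", ds[1] is "world" - the empty and
# too-short samples are never seen by the training loop.
```

Because __len__ reflects only the valid samples, the DataLoader's sampler never produces an out-of-range or invalid index, so batches always stay full.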
Actually, I made a MongoDB wrapper that collects several collections, and the wrapper is connected to a custom Dataset class.
So maybe I can pre-build a valid index for the collections in the MongoDB wrapper before the training loop, checking whether each document is valid or not.
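That plan could be sketched roughly as below. Everything here is hypothetical: `MongoWrapper`, its `count`/`get` methods, the `filtered_text` field, and the `is_valid` check are stand-ins (an in-memory list instead of real pymongo collections) just to show where the one-time index pass would sit.

```python
class MongoWrapper:
    """Stand-in for a wrapper over several MongoDB collections."""
    def __init__(self, documents):
        self.documents = documents

    def count(self):
        return len(self.documents)

    def get(self, idx):
        return self.documents[idx]

def is_valid(doc, min_len=3):
    # Validity check: the filtered text must exist and not be too short.
    text = doc.get("filtered_text", "")
    return len(text) >= min_len

def build_valid_index(wrapper):
    # One pass over all documents before the training loop starts.
    return [i for i in range(wrapper.count()) if is_valid(wrapper.get(i))]

wrapper = MongoWrapper([
    {"filtered_text": "long enough"},
    {"filtered_text": ""},
    {"filtered_text": "ok text"},
])
valid_idx = build_valid_index(wrapper)
# Only documents 0 and 2 pass the check and enter the Dataset.
```

The resulting valid_idx list would then be handed to the custom Dataset, exactly as in the answer above.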
I got some good hints from your answer. Thank you.