How to efficiently filter out classes

Say I have a custom Dataset instance with 200 classes and a million samples.
I want to create a custom Subset instance containing only the samples whose label is in a certain list of selected labels.
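One way to make repeated selections cheap is to group sample indices by class once, so any list of selected labels becomes a concatenation of precomputed lists. The sketch below assumes the integer labels are already available as a plain list, here called `targets`. That is a hypothetical attribute of the custom dataset, similar to the `targets` list some torchvision datasets expose. The helper names `build_class_index` and `indices_for_labels` are illustrative. The resulting index list could then be passed to `torch.utils.data.Subset(dataset, indices)`.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Sequence


def build_class_index(targets: Sequence[int]) -> Dict[int, List[int]]:
    """Group sample indices by integer label in a single pass over the labels."""
    index: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(targets):
        index[label].append(idx)
    return dict(index)


def indices_for_labels(class_index: Dict[int, List[int]], labels: Iterable[int]) -> List[int]:
    """Concatenate the stored indices of every selected label, in ascending order."""
    selected: List[int] = []
    for label in set(labels):  # set() so a repeated label is not added twice
        selected.extend(class_index.get(label, []))
    return sorted(selected)


# Hypothetical labels for six samples.
targets = [0, 2, 1, 2, 0, 3]
class_index = build_class_index(targets)
print(indices_for_labels(class_index, [2, 3]))  # [1, 3, 5]
```

Building the index costs one pass over the labels. After that, each new subset only touches the indices of the chosen classes.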

I tried to do this by looping through the indices as follows:

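# Keep only the indices whose class name is in chosen_classes.
for idx in self.indices:
    # dataset[idx] runs __getitem__, which returns the full sample, not just the label
    datapoint_label: int = dataset[idx][1]
    if dataset.label_to_class[datapoint_label] in chosen_classes:
        self.filtered_chosen_indices.append(idx)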
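The loop above calls `dataset[idx]`, i.e. `__getitem__`, once per sample just to read a label. In a typical dataset that call loads and transforms the sample, which is likely where most of the time goes. The sketch below uses a hypothetical `ToyDataset` that counts `__getitem__` calls and compares the original pattern with reading labels from a separately stored `targets` list. Storing the labels that way is an assumption about how the custom dataset could be organized, not something the question states.

```python
from typing import List, Set, Tuple


class ToyDataset:
    """Hypothetical stand-in for the custom Dataset: __getitem__ is the expensive part."""

    def __init__(self, targets: List[int]) -> None:
        self.targets = targets  # labels kept separately from the samples (assumption)
        self.getitem_calls = 0

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, idx: int) -> Tuple[List[float], int]:
        self.getitem_calls += 1  # a real dataset would load and transform a sample here
        return [0.0] * 1000, self.targets[idx]


def filter_via_getitem(dataset: ToyDataset, chosen_labels: Set[int]) -> List[int]:
    # Mirrors the original loop: every label lookup builds a full sample.
    return [idx for idx in range(len(dataset)) if dataset[idx][1] in chosen_labels]


def filter_via_targets(dataset: ToyDataset, chosen_labels: Set[int]) -> List[int]:
    # Reads only the cached labels; no sample is ever constructed.
    return [idx for idx, label in enumerate(dataset.targets) if label in chosen_labels]


ds = ToyDataset([0, 2, 1, 2, 0, 3])
print(filter_via_getitem(ds, {2, 3}), ds.getitem_calls)  # [1, 3, 5] 6
ds.getitem_calls = 0
print(filter_via_targets(ds, {2, 3}), ds.getitem_calls)  # [1, 3, 5] 0
```

Both functions select the same indices, but only the first one pays for building every sample.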

This works, but it is far too slow. I would expect PyTorch to offer an efficient way to do this. Does it?
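As far as I know, `torch.utils.data.Subset` only takes a precomputed list of indices and does not filter by class itself, so the filtering step is up to the user. If all labels fit in one array, that step can be vectorized. The sketch below uses NumPy's `np.isin` and `np.flatnonzero` on synthetic labels (200 classes, one million samples, generated here only for illustration). `torch.isin` plays a similar role on tensors in recent PyTorch versions. The function name `filter_indices_vectorized` is hypothetical.

```python
import numpy as np


def filter_indices_vectorized(targets, selected_labels) -> np.ndarray:
    """Return the positions whose label is in selected_labels, without a Python loop."""
    targets = np.asarray(targets)
    # Boolean mask with one entry per sample: True where the label is selected.
    mask = np.isin(targets, np.asarray(list(selected_labels)))
    return np.flatnonzero(mask)  # indices of the True entries, in ascending order


rng = np.random.default_rng(0)
targets = rng.integers(0, 200, size=1_000_000)  # synthetic labels, not real data
indices = filter_indices_vectorized(targets, [3, 17, 42])
print(len(indices), indices[:5])
```

`indices.tolist()` gives a plain list of ints that could be handed to `Subset`.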

Note: label_to_class is an attribute of the custom dataset; it is a Dict[int, str] that maps each integer label to its class name.
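Since label_to_class maps labels to names, the loop does a dictionary lookup and a name comparison for every sample, and if chosen_classes is a list, each `in` check scans it. A small sketch: translate the chosen class names into a set of integer labels once, so the per-sample test becomes an integer set membership check. It assumes chosen_classes is an iterable of class-name strings; the function name `class_names_to_labels` and the example classes are hypothetical.

```python
from typing import Dict, Iterable, Set


def class_names_to_labels(label_to_class: Dict[int, str], chosen_classes: Iterable[str]) -> Set[int]:
    """Translate chosen class names into their integer labels once, up front."""
    chosen = set(chosen_classes)
    labels = {label for label, name in label_to_class.items() if name in chosen}
    # Fail loudly on a typo instead of silently selecting nothing for that class.
    missing = chosen - {label_to_class[label] for label in labels}
    if missing:
        raise KeyError(f"unknown class names: {sorted(missing)}")
    return labels


label_to_class = {0: "cat", 1: "dog", 2: "bird", 3: "fish"}
chosen_labels = class_names_to_labels(label_to_class, ["dog", "fish"])
print(sorted(chosen_labels))  # [1, 3]
```

With chosen_labels in hand, the filter only needs each sample's integer label, never its class name.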