DataLoader not working with IterableDataset and Shuffler

import itertools
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

# `tokenization` and `du` are the project's own BERT-style helper modules (not shown here)

class SeqDataset(IterableDataset):
    def __init__(
        self,
        input_files,
        max_length,
        min_length,
        masked_lm_prob,
        max_predictions_per_seq,
        rng,
        tokenizer,
    ) -> None:
        super().__init__()
        if not isinstance(input_files, list):
            input_files = [input_files]
        self.input_files = input_files
        self.max_length = max_length
        self.min_length = min_length
        self.masked_lm_prob = masked_lm_prob
        self.max_predictions_per_seq = max_predictions_per_seq
        self.rng = rng
        self.tokenizer = tokenizer
    def mask_und_pad(self, tokenized_seq, vocab):
        segment_ids = [0] * len(tokenized_seq)
        (
            tokenized_seq,
            masked_lm_positions,
            masked_lm_labels,
        ) = du.create_masked_lm_predictions(
            tokenized_seq,
            self.masked_lm_prob,
            self.max_predictions_per_seq,
            vocab,
            self.rng,
        )
        input_ids = self.tokenizer.convert_tokens_to_ids(tokenized_seq)
        masked_lm_ids = self.tokenizer.convert_tokens_to_ids(masked_lm_labels)

        # pad the sequence-level features with 0s up to max_length
        input_mask = [1] * len(input_ids)
        while len(input_ids) < self.max_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        # pad the masked-LM features with 0s up to max_predictions_per_seq
        masked_lm_weights = [1.0] * len(masked_lm_ids)
        while len(masked_lm_positions) < self.max_predictions_per_seq:
            masked_lm_positions.append(0)
            masked_lm_ids.append(0)
            masked_lm_weights.append(0.0)

        # input_mask and masked_lm_weights are computed but not returned yet;
        # the caller (preprocess) converts these lists to tensors
        return input_ids, segment_ids, masked_lm_positions, masked_lm_ids

    def preprocess(self, sequence):
        line = tokenization.convert_to_unicode(sequence)
        tokenized_seq = self.tokenizer.tokenize(line)
        vocab = list(self.tokenizer.vocab.keys())
        # prepend the [CLS] token, then truncate at the token level so the final
        # length (including [CLS]) never exceeds max_length; truncating the raw
        # string by characters does not bound the number of tokens
        tokenized_seq = [tokenization.CLS_TOKEN] + tokenized_seq
        if len(tokenized_seq) > self.max_length:
            tokenized_seq = tokenized_seq[: self.max_length]
        # note: sequences shorter than min_length are not filtered out here
        (
            input_ids,
            segment_ids,
            masked_lm_positions,
            masked_lm_labels,
        ) = self.mask_und_pad(tokenized_seq, vocab)

        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(segment_ids),
            torch.LongTensor(masked_lm_positions),
            torch.LongTensor(masked_lm_labels),
        )
    def __iter__(self):
        # shard the stream of lines across DataLoader workers so that each
        # sequence is processed by exactly one worker
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            worker_id, worker_total_num = 0, 1
        else:
            worker_id, worker_total_num = worker_info.id, worker_info.num_workers

        # lazily read all input files line by line
        sequences = itertools.chain.from_iterable(open(f) for f in self.input_files)
        preprocessed = map(self.preprocess, sequences)
        return itertools.islice(preprocessed, worker_id, None, worker_total_num)

def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    if torch_seed >= 2**30:  # keep torch_seed + worker_id < 2**32, as required by numpy
        torch_seed = torch_seed % 2**30
    random.seed(torch_seed + worker_id)
    np.random.seed(torch_seed + worker_id)

tokenizer = tokenization.FullTokenizer()
max_length = 50
min_length = 0
masked_lm_prob = 0.15
max_predictions_per_seq = 3
rng = random.Random(1)
filepath = "data/test1.txt"

dataset1 = SeqDataset(
    filepath,
    max_length,
    min_length,
    masked_lm_prob,
    max_predictions_per_seq,
    rng,
    tokenizer,
)
shuffled_data1 = ShufflerIterDataPipe(dataset1, buffer_size=100)  # type: ignore
dataloader1 = DataLoader(
    shuffled_data1,
    batch_size=1,
    shuffle=True,
    num_workers=2,
    worker_init_fn=worker_init_fn,
)
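
The batches are then meant to be consumed roughly like this (illustrative only; the unpacking simply mirrors the four tensors returned by preprocess):

for input_ids, segment_ids, masked_lm_positions, masked_lm_ids in dataloader1:
    # each tensor gains a leading batch dimension of size batch_size (= 1 here)
    print(input_ids.shape, masked_lm_positions.shape)
    break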

I have a dataset containing sequences of RNA residues, and I am trying to develop a BERT model that takes each sequence from a file and learns from it. Since the dataset is pretty huge, I used an IterableDataset to load the data, but in that case shuffling is not possible, so I wrapped it in torch.utils.data.datapipes.iter.combinatorics.ShufflerIterDataPipe. However, when I run this on a GPU it raises the following error:

Expected a 'cuda' device type for generator but found 'cpu'.

To solve this issue I tried the solutions provided here and here, but nothing seems to work. I would really be grateful for any help in this regard; I have spent two days on this but couldn't find anything.

Thanks

  1. Shuffler should be used with an IterDataPipe, not an IterableDataset. Can you try that? (A rough sketch of what that could look like is below.)
  2. You may want to have a look at this tutorial if you are using a DataPipe with the DataLoader. You can also try it without a worker_init_fn.
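
For reference, here is a minimal sketch of suggestion 1, reusing the preprocess method of the posted SeqDataset. IterableWrapper, the shuffle/sharding_filter/map chain, and reading the file into an in-memory list of lines are illustrative choices on my part, not a verified fix for the generator error:

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# illustrative sketch: build the pipeline from DataPipes instead of an IterableDataset
lines = open(filepath).read().splitlines()     # in-memory source; fine for a small test file
datapipe = IterableWrapper(lines)
datapipe = datapipe.shuffle(buffer_size=100)   # Shuffler applied to an IterDataPipe
datapipe = datapipe.sharding_filter()          # shards lines across DataLoader workers
datapipe = datapipe.map(dataset1.preprocess)   # reuse the existing tokenize/mask/pad step

# shuffle=True is what actually enables the Shuffler inside the pipe;
# no worker_init_fn or manual islice sharding is needed with sharding_filter
dataloader = DataLoader(datapipe, batch_size=1, shuffle=True, num_workers=2)

If the 'cuda' generator error still shows up, passing an explicit CPU generator to the DataLoader (generator=torch.Generator()) is one more thing to try, though I have not verified that it resolves this particular error.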