Training crashes even after using memory-mapped NumPy arrays in a PyTorch Dataset

I am trying to load the SQuAD question answering dataset. Loading the whole dataset into PyTorch ran out of memory, so I wrote a custom Dataset that uses memory-mapped NumPy arrays and lazily reads only the required rows in `__getitem__`. For reference, my dataset is below:

import numpy as np
import torch
import torch.utils.data as data


class SQuAD(data.Dataset):
    """Stanford Question Answering Dataset (SQuAD).

    Each item in the dataset is a tuple with the following entries (in order):
        - context_idxs: Indices of the words in the context.
            Shape (context_len,).
        - context_char_idxs: Indices of the characters in the context.
            Shape (context_len, max_word_len).
        - question_idxs: Indices of the words in the question.
            Shape (question_len,).
        - question_char_idxs: Indices of the characters in the question.
            Shape (question_len, max_word_len).
        - y1: Index of word in the context where the answer begins.
            -1 if no answer.
        - y2: Index of word in the context where the answer ends.
            -1 if no answer.
        - id: ID of the example.

    Args:
        data_path (str): Path to .npz file containing pre-processed dataset.
        use_v2 (bool): Whether to use SQuAD 2.0 questions. Otherwise only use SQuAD 1.1.
    """
    def __init__(self, data_path, use_v2=True):
        super(SQuAD, self).__init__()

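        # Unpack each array in the .npz archive into its own .npy file (fixed names in the
        # current working directory) so that it can be memory-mapped below.
        # Note: obj[key] reads the entire array into RAM before it is saved.
        obj = np.load(data_path)
        np.save('context_idxs.npy', obj['context_idxs'])
        np.save('context_char_idxs.npy', obj['context_char_idxs'])
        np.save('ques_idxs.npy', obj['ques_idxs'])
        np.save('ques_char_idxs.npy', obj['ques_char_idxs'])
        np.save('y1s.npy', obj['y1s'])
        np.save('y2s.npy', obj['y2s'])
        np.save('ids.npy', obj['ids'])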

        # Reopen each .npy file as a read-only memory map; rows are read from disk on access
        self.context_idxs_memmap = np.load('context_idxs.npy', mmap_mode='r')
        self.context_char_idxs_memmap = np.load('context_char_idxs.npy', mmap_mode='r')
        self.ques_idxs_memmap = np.load('ques_idxs.npy', mmap_mode='r')
        self.ques_char_idxs_memmap = np.load('ques_char_idxs.npy', mmap_mode='r')
        self.y1s = np.load('y1s.npy', mmap_mode='r')
        self.y2s = np.load('y2s.npy', mmap_mode='r')
        batch_size, c_len, w_len = self.context_char_idxs_memmap.shape
        print('-----------shape of memmapped array is ', batch_size, c_len, w_len)
        self.w_len = w_len
        self.ids = np.load('ids.npy', mmap_mode='r')

    def __getitem__(self, idx):
        # idx = self.valid_idxs[idx]
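        # Prepend an extra token (index 1) to the context and question, then convert to LongTensors
        ones_row = np.ones((1, self.w_len), dtype=np.int64)
        example = (torch.from_numpy(np.concatenate([[1], self.context_idxs_memmap[idx]])).long(),
                   torch.from_numpy(np.concatenate([ones_row, self.context_char_idxs_memmap[idx]], axis=0)).long(),
                   torch.from_numpy(np.concatenate([[1], self.ques_idxs_memmap[idx]])).long(),
                   torch.from_numpy(np.concatenate([ones_row, self.ques_char_idxs_memmap[idx]], axis=0)).long(),
                   torch.tensor(self.y1s[idx], dtype=torch.long),
                   torch.tensor(self.y2s[idx], dtype=torch.long),
                   torch.tensor(self.ids[idx], dtype=torch.long))

        return example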

    def __len__(self):
        return self.ids.shape[0]
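The unpacking step at the start of `__init__` could also be written as a standalone helper. This is only a sketch under assumptions, not part of the original code: `unpack_npz` and its `out_dir` argument are hypothetical names. It writes each array of the `.npz` archive to `<out_dir>/<key>.npy` only if that file does not exist yet, then reopens every file with `mmap_mode='r'`, so giving each split its own directory keeps the training and validation files apart.

```python
import os

import numpy as np


def unpack_npz(data_path, out_dir):
    """Hypothetical helper: extract every array stored in the .npz file at
    `data_path` into `out_dir/<key>.npy` and return read-only memmaps keyed by name."""
    os.makedirs(out_dir, exist_ok=True)
    memmaps = {}
    with np.load(data_path) as archive:  # NpzFile; closed when the block exits
        for key in archive.files:
            npy_path = os.path.join(out_dir, key + '.npy')
            if not os.path.exists(npy_path):
                # archive[key] decompresses this one array fully into RAM;
                # it can be freed as soon as np.save has written it out
                np.save(npy_path, archive[key])
            # Reopen the .npy file as a memory map: rows are read from disk on access
            memmaps[key] = np.load(npy_path, mmap_mode='r')
    return memmaps
```

Inside `SQuAD.__init__`, a call such as `unpack_npz(data_path, os.path.splitext(data_path)[0] + '_npy')` (again hypothetical) would derive the cache directory from the input file, so two instances built from different `.npz` files would not write to the same `.npy` names.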

When I create a single instance of the SQuAD dataset (the training dataset), training runs without errors. However, when I also create a second instance for the validation data (as shown below), the Google Colab runtime crashes before training even begins.

train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
train_loader = data.DataLoader(train_dataset, ...)
dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
dev_loader = data.DataLoader(dev_dataset, ...)
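To see which of these lines pushes the Colab runtime over its memory limit, one option is to log the process's peak resident memory after each one. A minimal sketch using the standard-library `resource` module (Unix only); `peak_rss_mb` is a hypothetical helper, and the NumPy allocation merely stands in for a `SQuAD(...)` call. Pages of a memory-mapped file that have been read also count toward resident memory, although the kernel can evict those file-backed pages under memory pressure.

```python
import resource  # Unix-only standard-library module
import sys

import numpy as np


def peak_rss_mb():
    """Hypothetical helper: peak resident set size of this process in MiB.

    getrusage reports ru_maxrss in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    divisor = 1024 * 1024 if sys.platform == 'darwin' else 1024
    return peak / divisor


print(f'peak RSS at start: {peak_rss_mb():.1f} MiB')
# Stand-in for building a dataset: allocate and fill about 19 MiB of RAM
block = np.ones(20_000_000, dtype=np.uint8)
print(f'peak RSS after allocation: {peak_rss_mb():.1f} MiB')
```

In the training script, the same print could follow `train_dataset = SQuAD(...)` and again `dev_dataset = SQuAD(...)`.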

Any ideas why this is happening? I thought using memory-mapped arrays would fix the out-of-memory problem.
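The two kinds of `np.load` call in `__init__` behave differently, which can be checked on a toy array. A minimal sketch; the file names and the toy array are made up for illustration. Indexing the `.npz` archive returns a fully loaded in-memory array, while the `.npy` file opened with `mmap_mode='r'` comes back as an `np.memmap` that stays on disk until rows are accessed.

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    toy = np.arange(20, dtype=np.int32).reshape(5, 4)  # stand-in for context_idxs
    npz_path = os.path.join(tmp, 'toy.npz')
    npy_path = os.path.join(tmp, 'toy.npy')
    np.savez(npz_path, context_idxs=toy)
    np.save(npy_path, toy)

    # Indexing an NpzFile decompresses and returns the whole array as an ndarray
    with np.load(npz_path) as archive:
        from_npz = archive['context_idxs']

    # A .npy file opened with mmap_mode stays on disk; only accessed pages are read
    from_npy = np.load(npy_path, mmap_mode='r')
    npy_is_memmap = isinstance(from_npy, np.memmap)
    row = np.array(from_npy[2])  # copies just this row into an ordinary ndarray
    del from_npy  # drop the memmap before the temporary directory is removed

print(type(from_npz).__name__)  # ndarray
print(npy_is_memmap)            # True
print(row.tolist())             # [8, 9, 10, 11]
```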