Custom collate_fn with multi-process data loading hangs

I am using a custom collate_fn with a custom dataset.

The dataset:

import os
from typing import Any, Dict, List

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaForMaskedLM, RobertaTokenizerFast


class PrelimEmbedDataset(Dataset):
    '''Dataset of preliminary embeddings'''

    def __init__(
        self,
        split: str,
        datadir: str,
        ckpt_path: str,
        encoder_type: str = 'roberta',
        reduction: str = 'pool'
    ) -> None:
        '''Initializer

        Parameters
        ----------
        split
            'train' or 'test' or 'valid' splits
        datadir
            The data directory
        ckpt_path
            The path to the preliminary encoder model checkpoint
        encoder_type
            The preliminary model type
        reduction
            How to average out the output of the encoder model
        '''
        if encoder_type == 'roberta':
            self.tokenizer = RobertaTokenizerFast.from_pretrained(ckpt_path)
            self.encoder = RobertaForMaskedLM.from_pretrained(
                ckpt_path).roberta
        else:
            raise NotImplementedError(
                'Only "roberta" is supported for "transformer_type"'
            )
        self.data = pd.read_pickle(os.path.join(datadir, f'{split}.pkl'))
        assert reduction in ['pool', 'mean'], (
            'Only "pool" and "mean" options supported for "reduction"')
        self.reduction = reduction

    def __len__(self) -> int:
        '''Magic method to return length'''
        return len(self.data.index)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        '''Load and return the input tensor

        Parameters
        ----------
        idx
            The index
        '''
        row = self.data.iloc[idx]
        input_ = []
        with torch.no_grad():
            for chunk in row.conversation:
                tokenizer_out = self.tokenizer(chunk, return_tensors='pt')
                encoder_out = self.encoder(**tokenizer_out)
                if self.reduction == 'mean':
                    # Mean over the sequence dim; squeeze(0) (not squeeze())
                    # so a length-1 chunk keeps its (T, C) shape
                    encoder_out = encoder_out.last_hidden_state.squeeze(0).mean(0)
                elif self.reduction == 'pool':
                    # Embedding of the first (<s>) token: (1, T, C) -> (C,)
                    encoder_out = encoder_out.last_hidden_state.squeeze(0)[0]
                else:
                    raise NotImplementedError(
                        'Only "pool" and "mean" options supported '
                        'for "reduction"')
                input_.append(encoder_out)
            input_ = torch.stack(input_)
        return {
            'input': input_,  # TxC
            'idx': idx,
        }

Here, each row's conversation column is a list of sentences.
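
For reference, a minimal sketch of what the pickled DataFrame might look like (the column name conversation comes from the code above; the sentences and the path are made up):

import pandas as pd

df = pd.DataFrame({
    'conversation': [
        ['Hi, how can I help?', 'My order has not arrived.'],
        ['Hello.', 'I would like a refund.', 'Sure, one moment.'],
    ],
})
df.to_pickle('data/train.pkl')  # read back by PrelimEmbedDataset as f'{split}.pkl'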

The collater function:

def same_shape_collater(inputs: List[Dict[str, Any]]) -> Dict[str, Any]:
    '''Pad/truncate inputs to the same shape so that they can be batched

    Parameters
    ----------
    inputs
        The list of input tensors to batch into a single batch-tensor
    '''
    idxs = [input_['idx'] for input_ in inputs]
    inputs = [input_['input'] for input_ in inputs]
    embed_dim = inputs[0].shape[1]
    seq_lengths = [input_.shape[0] for input_ in inputs]
    target_size = max(seq_lengths)
    collated_inputs = inputs[0].new_zeros(
        len(inputs), target_size, embed_dim)
    padding_mask = torch.zeros(collated_inputs.shape, dtype=torch.bool)
    for i, (input_, length) in enumerate(zip(inputs, seq_lengths)):
        diff = target_size - length
        if diff == 0:
            collated_inputs[i] = input_
        else:
            collated_inputs[i] = torch.vstack(
                [input_, input_.new_full((diff, embed_dim), 0.0)]
            )
            padding_mask[i, -diff:, :] = True
    return {
        'idx': idxs,
        'features': collated_inputs,  # BxTxC
        'padding_mask': padding_mask  # BxTxC
    }
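
As a quick sanity check, the collater can be exercised directly on dummy inputs, independent of the DataLoader (768 here is just roberta-base's hidden size; the shapes are otherwise arbitrary):

import torch

dummy = [
    {'input': torch.randn(3, 768), 'idx': 0},  # T = 3
    {'input': torch.randn(5, 768), 'idx': 1},  # T = 5
]
batch = same_shape_collater(dummy)
print(batch['features'].shape)         # torch.Size([2, 5, 768])
print(batch['padding_mask'][0, :, 0])  # tensor([False, False, False,  True,  True])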

Then I use torch.utils.data.DataLoader to load the samples as follows:

train_dataset = PrelimEmbedDataset(
    split='train', datadir=args.datadir, ckpt_path=args.prelim_model,
    encoder_type=args.prelim_type, reduction=args.prelim_reduction
)
train_dataloader = DataLoader(
    train_dataset, shuffle=True,
    batch_size=8,
    num_workers=8,
    collate_fn=same_shape_collater
)

for sample in train_dataloader:
    ...  # train code

When I run this, the dataloader keeps loading data samples without ever calling the collate function, and the loop hangs. But when I set num_workers=0, everything works fine. Can someone help me with this?

Hello! I ran into the same problem and don't know how to fix it. Sad.

What makes you think your collate function is not called? One thing I discovered is that print statements from worker processes may never land in your console, even though the work is still done. If you compare the output of the collate function in single-process mode versus multi-process mode, what does it look like? What is the difference?
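
If you want to verify this directly, one option is to have the collate function write to a file instead of relying on stdout (a minimal sketch; same_shape_collater_debug is just a hypothetical wrapper name):

import os

def same_shape_collater_debug(inputs):
    # Each worker process appends to its own PID-stamped file, so the log
    # survives even when prints from forked workers never reach the console.
    with open(f'/tmp/collate_{os.getpid()}.log', 'a') as f:
        f.write(f'collate called with {len(inputs)} samples\n')
    return same_shape_collater(inputs)

Passing collate_fn=same_shape_collater_debug to the DataLoader and then checking for /tmp/collate_*.log files after a run would show whether the collater is actually reached inside the worker processes.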