Custom collate_fn with multiprocess data loading is stuck

I am using a custom collate_fn with a custom dataset.

The dataset:

import os
from typing import Any, Dict, List

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaForMaskedLM, RobertaTokenizerFast


class PrelimEmbedDataset(Dataset):
    '''Dataset of preliminary embeddings'''

    def __init__(
        self,
        split: str,
        datadir: str,
        ckpt_path: str,
        encoder_type: str = 'roberta',
        reduction: str = 'pool'
    ) -> None:
        '''Initializer

        Parameters
        ----------
        split
            'train' or 'test' or 'valid' splits
        datadir
            The data directory
        ckpt_path
            The path to the preliminary encoder model checkpoint
        encoder_type
            The preliminary model type
        reduction
            How to average out the output of the encoder model
        '''
        if encoder_type == 'roberta':
            self.tokenizer = RobertaTokenizerFast.from_pretrained(ckpt_path)
            self.encoder = RobertaForMaskedLM.from_pretrained(
                ckpt_path).roberta
        else:
            raise NotImplementedError(
                'Only "roberta" is supported for "encoder_type"'
            )
        self.data = pd.read_pickle(os.path.join(datadir, f'{split}.pkl'))
        assert reduction in ['pool', 'mean'], (
            'Only "pool" and "mean" options supported for "reduction"')
        self.reduction = reduction

    def __len__(self) -> int:
        '''Magic method to return length'''
        return len(self.data.index)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        '''Embed one conversation and return the input tensor and its index

        Parameters
        ----------
        idx
            The index
        '''
        row = self.data.iloc[idx]
        input_ = []
        with torch.no_grad():
            for chunk in row.conversation:
                tokenizer_out = self.tokenizer(chunk, return_tensors='pt')
                encoder_out = self.encoder(**tokenizer_out)
                if self.reduction == 'mean':
                    encoder_out = encoder_out.last_hidden_state.squeeze().mean(
                        0)
                elif self.reduction == 'pool':
                    encoder_out = encoder_out.last_hidden_state.squeeze()[0]
                else:
                    raise NotImplementedError(
                        'Only "pool" and "mean" options supported '
                        'for "reduction"')
                input_.append(encoder_out)
            input_ = torch.stack(input_)
        return {
            'input': input_,  # TxC
            'idx': idx,
        }

Here, the row.conversation column is a list of sentences (one string per chunk of the conversation).
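
For reference, this is a minimal sketch of what the pickled DataFrame is assumed to look like (the sentences below are made up for illustration):

# Hypothetical contents of train.pkl: a single 'conversation' column whose
# cells are lists of sentence strings.
example_df = pd.DataFrame({
    'conversation': [
        ['Hi, how can I help you?', 'My order has not arrived yet.'],
        ['Thanks for calling.', 'I would like to change my plan.'],
    ]
})
example_df.to_pickle('train.pkl')  # read back by PrelimEmbedDataset via pd.read_pickle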

The collater function:

def same_shape_collater(inputs: List[Dict[str, Any]]) -> Dict[str, Any]:
    '''Pad inputs to the same length so that they can be batched

    Parameters
    ----------
    inputs
        The list of samples (dicts from __getitem__) to batch into a single tensor
    '''
    idxs = [input_['idx'] for input_ in inputs]
    inputs = [input_['input'] for input_ in inputs]
    embed_dim = inputs[0].shape[1]
    seq_lengths = [input_.shape[0] for input_ in inputs]
    target_size = max(seq_lengths)
    collated_inputs = inputs[0].new_zeros(
        len(inputs), target_size, embed_dim)
    padding_mask = torch.zeros(collated_inputs.shape, dtype=torch.bool)
    for i, (input_, length) in enumerate(zip(inputs, seq_lengths)):
        diff = target_size - length
        if diff == 0:
            collated_inputs[i] = input_
        else:
            collated_inputs[i] = torch.vstack(
                [input_, input_.new_full((diff, embed_dim), 0.0)]
            )
            padding_mask[i, -diff:, :] = True
    return {
        'idx': idxs,
        'features': collated_inputs,  # BxTxC
        'padding_mask': padding_mask  # BxTxC
    }
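
As a sanity check, the collater behaves as expected when called directly on toy samples (the 768 embedding dimension below is just illustrative, assuming a roberta-base encoder):

# Calling same_shape_collater outside the DataLoader on fake samples.
fake_samples = [
    {'input': torch.randn(3, 768), 'idx': 0},  # conversation with 3 chunks
    {'input': torch.randn(5, 768), 'idx': 1},  # conversation with 5 chunks
]
batch = same_shape_collater(fake_samples)
print(batch['features'].shape)             # torch.Size([2, 5, 768])
print(batch['padding_mask'].shape)         # torch.Size([2, 5, 768])
print(batch['padding_mask'][0, 3:].all())  # tensor(True): padded positions are masked
print(batch['padding_mask'][1].any())      # tensor(False): longest sample needs no padding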

Then I load the samples with torch.utils.data.DataLoader as follows:

train_dataset = PrelimEmbedDataset(
    split='train', datadir=args.datadir, ckpt_path=args.prelim_model,
    encoder_type=args.prelim_type, reduction=args.prelim_reduction
)
train_dataloader = DataLoader(train_dataset, shuffle=True,
                              batch_size=8,
                              num_workers=8,
                              collate_fn=same_shape_collater)

for sample in train_dataloader:
    pass  # train code

When I run this, the DataLoader keeps loading data samples but never calls the collate function, so the loop never yields a batch. However, when I set num_workers=0, everything works fine. Can someone help me with this?
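
For completeness, the same loop with a single-process loader (num_workers=0) runs and calls the collate function as expected; the batch shapes in the comments are illustrative:

# Identical setup except num_workers=0: this iterates normally.
debug_dataloader = DataLoader(train_dataset, shuffle=True,
                              batch_size=8,
                              num_workers=0,
                              collate_fn=same_shape_collater)

for sample in debug_dataloader:
    print(sample['features'].shape)  # e.g. torch.Size([8, max_chunks, 768])
    break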