I am using a custom collate_fn
with a custom dataset.
The dataset:
import os
from typing import List

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaForMaskedLM, RobertaTokenizerFast


class PrelimEmbedDataset(Dataset):
    '''Dataset of preliminary embeddings'''

    def __init__(
        self,
        split: str,
        datadir: str,
        ckpt_path: str,
        encoder_type: str = 'roberta',
        reduction: str = 'pool'
    ) -> None:
        '''Initializer

        Parameters
        ----------
        split
            'train', 'test' or 'valid' split
        datadir
            The data directory
        ckpt_path
            The path to the preliminary encoder model checkpoint
        encoder_type
            The preliminary model type
        reduction
            How to reduce the encoder output to one vector per chunk
            ('pool' takes the first token, 'mean' averages over tokens)
        '''
        if encoder_type == 'roberta':
            self.tokenizer = RobertaTokenizerFast.from_pretrained(ckpt_path)
            self.encoder = RobertaForMaskedLM.from_pretrained(
                ckpt_path).roberta
        else:
            raise NotImplementedError(
                'Only "roberta" is supported for "encoder_type"'
            )
        self.data = pd.read_pickle(os.path.join(datadir, f'{split}.pkl'))
        assert reduction in ['pool', 'mean'], (
            'Only "pool" and "mean" options supported for "reduction"')
        self.reduction = reduction

    def __len__(self) -> int:
        '''Magic method to return length'''
        return len(self.data.index)

    def __getitem__(self, idx: int) -> dict:
        '''Load and return the input tensor

        Parameters
        ----------
        idx
            The index
        '''
        row = self.data.iloc[idx]
        input_ = []
        with torch.no_grad():
            for chunk in row.conversation:
                tokenizer_out = self.tokenizer(chunk, return_tensors='pt')
                encoder_out = self.encoder(**tokenizer_out)
                if self.reduction == 'mean':
                    # Average over the token dimension
                    encoder_out = encoder_out.last_hidden_state.squeeze().mean(0)
                elif self.reduction == 'pool':
                    # Take the first (<s>) token's embedding
                    encoder_out = encoder_out.last_hidden_state.squeeze()[0]
                else:
                    raise NotImplementedError(
                        'Only "pool" and "mean" options supported '
                        'for "reduction"')
                input_.append(encoder_out)
        input_ = torch.stack(input_)
        return {
            'input': input_,  # TxC
            'idx': idx,
        }
Here, the row.conversation column is a list of sentences.
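For context, each split's pickle is a pandas DataFrame with one row per conversation (the code above reads it with pd.read_pickle and indexes it with .iloc). A toy frame in that shape, with made-up values, can be built like this:

import pandas as pd

# Hypothetical example in the same shape as {split}.pkl:
# one row per conversation, each a list of sentence strings.
toy = pd.DataFrame({
    'conversation': [
        ['Hello there.', 'How are you?'],
        ['Fine, thanks.', 'And you?', 'Great.'],
    ]
})
toy.to_pickle('train.pkl')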
The collater function:
def same_shape_collater(inputs: List[dict]):
    '''Pad inputs to the same length so that they can be batched

    Parameters
    ----------
    inputs
        The list of sample dicts (as returned by __getitem__) to batch
        into a single batch-tensor
    '''
    idxs = [input_['idx'] for input_ in inputs]
    inputs = [input_['input'] for input_ in inputs]
    embed_dim = inputs[0].shape[1]
    seq_lengths = [input_.shape[0] for input_ in inputs]
    target_size = max(seq_lengths)
    collated_inputs = inputs[0].new_zeros(
        len(inputs), target_size, embed_dim)
    padding_mask = torch.BoolTensor(collated_inputs.shape).fill_(False)
    for i, (input_, length) in enumerate(zip(inputs, seq_lengths)):
        diff = target_size - length
        if diff == 0:
            collated_inputs[i] = input_
        else:
            # Zero-pad the short sequence and mark the padded rows
            collated_inputs[i] = torch.vstack(
                [input_, input_.new_full((diff, embed_dim), 0.0)]
            )
            padding_mask[i, -diff:, :] = True
    return {
        'idx': idxs,
        'features': collated_inputs,  # BxTxC
        'padding_mask': padding_mask  # BxTxC
    }
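To sanity-check the collater in isolation, it can be called directly on a couple of fake samples (the shapes below are arbitrary):

import torch

# Two fake samples with T=3 and T=5 time steps, C=4 channels
fake_batch = [
    {'input': torch.randn(3, 4), 'idx': 0},
    {'input': torch.randn(5, 4), 'idx': 1},
]
out = same_shape_collater(fake_batch)
print(out['features'].shape)         # torch.Size([2, 5, 4])
print(out['padding_mask'][0, :, 0])  # tensor([False, False, False, True, True])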
Then I am using the torch.utils.data.DataLoader
to load the samples as follows:
train_dataset = PrelimEmbedDataset(
    split='train', datadir=args.datadir, ckpt_path=args.prelim_model,
    encoder_type=args.prelim_type, reduction=args.prelim_reduction
)
train_dataloader = DataLoader(train_dataset, shuffle=True,
                              batch_size=8,
                              num_workers=8,
                              collate_fn=same_shape_collater)
for sample in train_dataloader:
    ...  # train code
When I run this, the dataloader keeps loading data samples without ever calling the collate function. But when I set num_workers=0, everything works as expected.
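One way to verify whether the collate function is being invoked at all is a thin module-level wrapper that logs each call (a debugging sketch; note that with num_workers > 0 the collation runs inside the worker subprocesses, so the prints come from those processes):

import os

def logged_same_shape_collater(batch):
    '''same_shape_collater with a debug print; kept at module level so it
    can be pickled if the workers are started with the "spawn" method.'''
    print(f'collate called in PID {os.getpid()} with {len(batch)} samples',
          flush=True)
    return same_shape_collater(batch)

train_dataloader = DataLoader(train_dataset, shuffle=True,
                              batch_size=8, num_workers=8,
                              collate_fn=logged_same_shape_collater)

Can someone help me with this?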