How can I split the following batch by modality and forward each split to a different encoder?
# Example: route each sample of a mixed-modality batch to the encoder that
# matches its modality.
desc_encoder = DescEncoder(...)
img_encoder = ImgEncoder(...)

# NOTE: the `...` entries below stand for truncated token ids from the original
# example (torch.tensor cannot build a tensor from Ellipsis); replace them with
# the real values before running.
batch = {
    'idx': torch.tensor([0, 1, 2, 3]),
    'modality': ['desc', 'img', 'desc', 'img'],
    'sample': torch.tensor([
        [101, 35596, 174818, ..., 0, 0, 0],
        [101, 322806, 347627, ..., 0, 0, 0],
        [101, 35596, 174818, ..., 0, 0, 0],
        [101, 35596, 174818, ..., 0, 0, 0],
    ]),
}

samples = batch['sample']

# 'modality' is a plain Python list of strings, so turn it into boolean masks
# explicitly. Building them on the same device as 'sample' avoids a device
# mismatch when the batch has been moved to a GPU.
desc_mask = torch.tensor(
    [m == 'desc' for m in batch['modality']],
    dtype=torch.bool,
    device=samples.device,
)
img_mask = torch.tensor(
    [m == 'img' for m in batch['modality']],
    dtype=torch.bool,
    device=samples.device,
)

# Boolean-mask indexing selects the matching rows and preserves their order.
desc_sample = samples[desc_mask]  # only samples with desc modality
img_sample = samples[img_mask]    # only samples with img modality

# Keep the matching 'idx' values so each encoder's outputs can be mapped back
# to their positions in the original batch.
desc_idx = batch['idx'][desc_mask.to(batch['idx'].device)]
img_idx = batch['idx'][img_mask.to(batch['idx'].device)]

# NOTE(review): if a batch contains no sample of some modality, its split has
# zero rows; confirm the encoders accept an empty batch, or skip the call.
desc_rpr = desc_encoder(desc_sample)
img_rpr = img_encoder(img_sample)