Apply different encoders to a batch of tensors

How can I split the following batch by modality and forward each split through a different encoder?


import torch

desc_encoder = DescEncoder(...)
img_encoder = ImgEncoder(...)

batch = {
    'idx': torch.tensor([0, 1, 2, 3]),
    'modality': ['desc', 'img', 'desc', 'img'],
    'sample': torch.tensor([
               [   101,  35596, 174818, ...,  0,      0,      0],
               [   101, 322806, 347627,  ..., 0,      0,      0],
               [   101,  35596, 174818,  ..., 0,      0,      0],
               [   101,  35596, 174818,  ..., 0,      0,      0]
               ])
}

desc_sample = ...  # only the samples with 'desc' modality
img_sample = ...   # only the samples with 'img' modality

desc_rpr = desc_encoder(desc_sample)
img_rpr = img_encoder(img_sample)

Hi Celso, would this work?

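import numpy as np
import torch

mod = np.array(batch["modality"])
sample = batch["sample"]
# Boolean masks over the batch rows, converted to torch.bool tensors
desc_sample = sample[torch.from_numpy(mod == "desc")]
img_sample = sample[torch.from_numpy(mod == "img")]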
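If you also need the representations back in the original batch order (e.g. to line them up with `batch['idx']`), one option is to loop over an `nn.ModuleDict` keyed by modality name and scatter each encoder's output into a preallocated tensor. This is only a sketch under some assumptions: `ToyEncoder` is a hypothetical stand-in for `DescEncoder`/`ImgEncoder`, every encoder is assumed to return a `(n, hidden)` tensor with the same `hidden`, and `encode_by_modality`, `hidden` and `vocab_size` are illustrative names, not part of any library. The toy batch uses small made-up token ids so it runs end to end.

```python
import numpy as np
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for DescEncoder / ImgEncoder:
    maps a (n, seq_len) LongTensor of token ids to a (n, hidden) tensor."""

    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden, padding_idx=0)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Mean over all positions; padding tokens (id 0) embed to zero vectors.
        return self.emb(tokens).mean(dim=1)


def encode_by_modality(batch, encoders: nn.ModuleDict, hidden: int) -> torch.Tensor:
    """Forward each row of batch['sample'] through the encoder for its modality
    and return the representations in the original batch order.
    Rows whose modality has no encoder are left as zeros."""
    mod = np.array(batch["modality"])
    sample = batch["sample"]
    out = torch.zeros(sample.size(0), hidden, device=sample.device)
    for name, encoder in encoders.items():
        mask = torch.from_numpy(mod == name).to(sample.device)  # torch.bool row mask
        if mask.any():
            # Each encoder only sees its own rows; results are written back in place.
            out[mask] = encoder(sample[mask])
    return out


# Usage with a toy batch (assumed data, not the original token ids).
torch.manual_seed(0)
hidden = 8
encoders = nn.ModuleDict({
    "desc": ToyEncoder(vocab_size=1000, hidden=hidden),
    "img": ToyEncoder(vocab_size=1000, hidden=hidden),
})
batch = {
    "idx": torch.tensor([0, 1, 2, 3]),
    "modality": ["desc", "img", "desc", "img"],
    "sample": torch.tensor([
        [101, 5, 7, 0],
        [101, 9, 3, 0],
        [101, 5, 7, 0],
        [101, 5, 7, 0],
    ]),
}
rpr = encode_by_modality(batch, encoders, hidden)
print(rpr.shape)  # torch.Size([4, 8])
```

Writing into `out[mask]` keeps autograd working (gradients flow back into each encoder), and adding a third modality only means adding another entry to the `ModuleDict`. Note that each encoder then runs on a variable-sized sub-batch, which matters if an encoder uses batch statistics (e.g. BatchNorm).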