I want to use a DeiT model with a different tokenizer for each input type. For example, if my dataset has four types of data, I want four different tokenizers, one per type, but a single shared encoder block. Each batch contains a mix of all the types. If a sample is of type 0, its forward pass should be tokenizer0 → encoder → classifier, and likewise for the other types.
Code:
import torch
import torch.nn as nn
from transformers import DeiTModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeitWithClassifier(nn.Module):
    def __init__(self):
        super(DeitWithClassifier, self).__init__()
        # One patch-embedding ("tokenizer") module per input type, each initialized
        # from the same pretrained checkpoint.
        self.tokenizers = {
            0: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            1: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            2: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            3: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
        }
        # Shared encoder, final layer norm, and pooler from the same checkpoint.
        deit_model = DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224")
        self.encoder = deit_model.encoder
        self.layer_norm = deit_model.layernorm
        self.pooler = deit_model.pooler
        self.projector = nn.Linear(192, 512)  # 192 = DeiT-tiny hidden size
        self.classifier = nn.Linear(512, 1)

    def forward(self, batch):
        output, embedding = [], []
        # Process one sample at a time, picking the tokenizer that matches its type.
        for t in range(len(batch['type'])):
            tokenizer = self.tokenizers[batch['type'][t]]
            tokenized = tokenizer(batch['pixel_values'][t].unsqueeze(0).to(device))
            encoded = self.encoder(tokenized).last_hidden_state
            normalized = self.layer_norm(encoded)
            # Project the CLS token of this single-sample mini-batch.
            emb = self.projector(normalized[0][0])
            logits = self.classifier(emb)
            output.append(logits)
            embedding.append(emb)
        # Stack the per-sample results back into batch tensors.
        output = torch.stack(output)
        embedding = torch.stack(embedding)
        return output, embedding
Each batch has a 'type' entry that gives the type of each sample and a 'pixel_values' entry with the images.
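For reference, this is roughly how I build and feed a batch. The type ids, shapes, and the manual device handling below are just illustrative of my current setup, not a fixed API:

# Roughly what one batch looks like after my DataLoader's collate step
# (type ids and shapes are illustrative):
batch = {
    "type": [0, 2, 1, 3],                          # one type id per sample
    "pixel_values": torch.randn(4, 3, 224, 224),   # DeiT-tiny expects 224x224 images
}

model = DeitWithClassifier().to(device)
# The tokenizers live in a plain dict, so model.to(device) does not move them;
# right now I move them by hand, which is part of why I'm unsure about this design.
for tok in model.tokenizers.values():
    tok.to(device)

output, embedding = model(batch)
print(output.shape, embedding.shape)   # torch.Size([4, 1]) torch.Size([4, 512])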
Although this code runs, I am not sure whether backpropagation is happening the way it is supposed to.
Is it OK to build Python lists in the forward pass and stack them at the end?
Does that interfere with the gradients/backpropagation? How else could I implement this?
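For comparison, here is a sketch of the kind of alternative I have been considering: registering the per-type tokenizers in an nn.ModuleDict so they show up in model.parameters() and move with model.to(device), and grouping the samples of each type so every tokenizer runs once per batch instead of once per sample. The class name and argument names here are placeholders:

# (same imports as above: torch, torch.nn as nn, DeiTModel)
class DeitWithClassifierGrouped(nn.Module):
    def __init__(self, num_types=4):
        super().__init__()
        # nn.ModuleDict registers the tokenizers as submodules, so they are
        # included in model.parameters() and moved by model.to(device).
        self.tokenizers = nn.ModuleDict({
            str(i): DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings
            for i in range(num_types)
        })
        deit_model = DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224")
        self.encoder = deit_model.encoder
        self.layer_norm = deit_model.layernorm
        self.projector = nn.Linear(192, 512)
        self.classifier = nn.Linear(512, 1)

    def forward(self, pixel_values, types):
        # pixel_values: (B, 3, 224, 224); types: (B,) long tensor of type ids.
        # 198 = 196 patches + CLS + distillation token, 192 = hidden size
        # for deit-tiny-distilled at 224x224.
        tokens = torch.empty(
            pixel_values.size(0), 198, 192,
            device=pixel_values.device, dtype=pixel_values.dtype,
        )
        # Run each tokenizer once on all samples of its type and scatter the
        # results back into their original batch positions.
        for type_id, tokenizer in self.tokenizers.items():
            mask = types == int(type_id)
            if mask.any():
                tokens[mask] = tokenizer(pixel_values[mask])
        encoded = self.encoder(tokens).last_hidden_state
        normalized = self.layer_norm(encoded)
        emb = self.projector(normalized[:, 0])   # CLS token of every sample
        logits = self.classifier(emb)
        return logits, emb

Here types would be a long tensor such as torch.tensor([0, 2, 1, 3]) rather than a Python list. Would something along these lines be the recommended way, or is my original loop-and-stack version fine as it is?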