Forward pass through different modules based on input type

I want to use a DeiT model with different tokenizers depending on the input type. For example, if my dataset has four types of data, I want to use four different tokenizers, one per type, but a single shared encoder block. Each batch will contain a mix of all types. If a sample is of type 0, the forward pass should look like tokenizer0 → encoder → classifier, and so on for the other types.
Code:

import torch
import torch.nn as nn
from transformers import DeiTModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DeitWithClassifier(nn.Module):
    def __init__(self):
        super(DeitWithClassifier, self).__init__()
        # One patch-embedding module ("tokenizer") per input type, each with its own weights
        self.tokenizers = {
            0: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            1: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            2: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
            3: DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings,
        }
        # A single shared encoder, layer norm and pooler
        deit_model = DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224")
        self.encoder = deit_model.encoder
        self.layer_norm = deit_model.layernorm
        self.pooler = deit_model.pooler
        self.projector = nn.Linear(192, 512)
        self.classifier = nn.Linear(512, 1)

    def forward(self, batch):
        output, embedding = [], []

        # Route each sample through the tokenizer matching its type
        for t in range(len(batch['type'])):
            tokenizer = self.tokenizers[batch['type'][t]]

            tokenized = tokenizer(batch['pixel_values'][t].unsqueeze(0).to(device))
            encoded = self.encoder(tokenized).last_hidden_state
            normalized = self.layer_norm(encoded)

            # Project the CLS token of the (single-sample) sequence and classify it
            emb = self.projector(normalized[0][0])
            logits = self.classifier(emb)

            output.append(logits)
            embedding.append(emb)

        output = torch.stack(output)
        embedding = torch.stack(embedding)

        return output, embedding

Each batch has a 'type' entry indicating the type of each sample, and 'pixel_values'.
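For concreteness, a batch looks roughly like this (illustrative shapes and values only; the exact collate function is up to the dataloader):

# Illustrative only: four samples of mixed types in one batch
batch = {
    'type': [0, 2, 1, 3],                          # one type id per sample
    'pixel_values': torch.randn(4, 3, 224, 224),   # stacked image tensors
}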
Though this code works, I am not sure the backpropagation is happening the way it is supposed to.
Is using a list in the forward pass OK?
Does it mess with the gradients/backpropagation? How else can I implement this?

Yes, appending activations to a Python list in the forward pass is fine, and you won't break the computation graph as long as you build the final tensor from that list with torch.cat or torch.stack. You should not create a new tensor via torch.tensor(list_entries), since that detaches the newly created tensor from the computation graph.
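A quick, self-contained way to see the difference (toy tensors only, nothing DeiT-specific):

import torch

x = torch.randn(4, requires_grad=True)
pieces = [x[i] * 2 for i in range(4)]   # Python list of per-sample results, each still on the graph

stacked = torch.stack(pieces)           # stacking keeps the graph intact
print(stacked.grad_fn)                  # StackBackward0, so it is still connected
stacked.sum().backward()
print(x.grad)                           # tensor([2., 2., 2., 2.]), gradients reach x

rebuilt = torch.tensor([p.item() for p in pieces])   # rebuilding from raw Python values
print(rebuilt.grad_fn, rebuilt.requires_grad)        # None False, detached from the graph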

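On the "how else can I implement this" part: one detail worth double-checking in the posted code is that the per-type tokenizers live in a plain Python dict, so they are not registered as submodules; model.parameters() (and therefore the optimizer) and model.to(device) will not pick them up. Wrapping them in nn.ModuleDict fixes that. A minimal sketch of the change (note that nn.ModuleDict keys must be strings):

import torch.nn as nn
from transformers import DeiTModel

# Register the per-type tokenizers so they are trained and moved together with the model
tokenizers = nn.ModuleDict({
    str(i): DeiTModel.from_pretrained("facebook/deit-tiny-distilled-patch16-224").embeddings
    for i in range(4)
})

# Lookup inside forward (keys are strings):
# tokenizer = self.tokenizers[str(int(batch['type'][t]))]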