How to create an efficient conditional layer in PyTorch?

I have a ResNet-50 model that outputs a class prediction (1, 2, or 3). Based on that classifier output, I want to select the next layer/model and make another prediction with it.

This is what I have so far.

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model1.weight)

        self.model2 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model2.weight)

        self.model3 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model3.weight)

    def forward(self, x):
        
        # x is shaped (1, batch_size); each entry is both the input value
        # and the class id that selects which submodel to use
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)

        # Loop over every value in the batch, one forward pass per element
        for i in range(batch_size):
            value = x[:, i]  # shape (1,)
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)

        return output

model = SimpleModel()

output = model(torch.tensor([[1, 2, 3]], dtype=torch.float32))
output

tensor([[1.],
        [2.],
        [3.]], grad_fn=<CopySlices>)

My concern is that I am only computing one forward pass on each iteration of the loop, which seems very inefficient. What happens if I increase the batch size to 64? Will the forward passes be computed in parallel?

Any thoughts/ideas would be appreciated.

It depends on the size of the model, but some simple initial solutions you could try are:

  1. simply running all three models for each input and only selecting the chosen output (e.g., via torch.gather; see the sketch after this list)
  2. accumulating inputs for each model until there are enough to run a full batch (for whatever batch size you choose)
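
To make option 1 concrete, here is a minimal sketch for the toy model above. It assumes inputs reshaped to (batch_size, 1) with the batch dimension first, and that the input value doubles as the class id; every submodel runs on the full batch, and torch.gather picks the matching output per row:

import torch
import torch.nn as nn

class RunAllModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(1, 1, bias=False)
        self.model2 = nn.Linear(1, 1, bias=False)
        self.model3 = nn.Linear(1, 1, bias=False)
        for m in (self.model1, self.model2, self.model3):
            nn.init.ones_(m.weight)

    def forward(self, x):
        # x: (batch_size, 1) with values in {1, 2, 3}
        # One batched forward pass per submodel, concatenated: (batch_size, 3)
        all_outputs = torch.cat(
            [self.model1(x), self.model2(x), self.model3(x)], dim=1
        )
        # Value 1 selects column 0, value 2 column 1, value 3 column 2
        index = x.long() - 1  # (batch_size, 1)
        return torch.gather(all_outputs, 1, index)

model = RunAllModel()
model(torch.tensor([[1.0], [2.0], [3.0]]))
# tensor([[1.], [2.], [3.]])

The tradeoff is roughly 3x the compute per batch, but everything stays vectorized, so for small submodels it can still beat the per-element Python loop.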

Beyond that, I think efficient data-dependent control flow on a per-input basis is a tricky problem.
There has also been interesting work on leveraging sparse primitives for Mixture-of-Experts (MoE) training for LLMs (https://arxiv.org/pdf/2211.15841.pdf), so I wonder if that could be applicable here.

If you’re okay with, e.g., “soft conditioning,” then this paper might be interesting: CondConv: Conditionally Parameterized Convolutions for Efficient Inference (arXiv:1904.04971).


Hey @eqy, thanks for the reply.

I really like the second idea. Even though it would require multiple passes over the dataset, I think it would be the fastest option. I will try it out!
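
For reference, here is a rough sketch of what I have in mind, with the batch dimension first. It is only an assumed starting point: it groups inputs within a single batch by class via boolean masks, rather than accumulating sub-batches across the dataset as suggested.

import torch
import torch.nn as nn

class BatchedConditionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(1, 1, bias=False)
        self.model2 = nn.Linear(1, 1, bias=False)
        self.model3 = nn.Linear(1, 1, bias=False)
        for m in (self.model1, self.model2, self.model3):
            nn.init.ones_(m.weight)

    def forward(self, x):
        # x: (batch_size, 1) with values in {1, 2, 3}
        output = torch.zeros_like(x)
        for value, model in ((1, self.model1), (2, self.model2), (3, self.model3)):
            mask = (x == value).squeeze(1)  # (batch_size,) boolean
            if mask.any():
                # One forward pass over the whole sub-batch for this class
                output[mask] = model(x[mask])
        return output

model = BatchedConditionalModel()
model(torch.tensor([[1.0], [2.0], [3.0]]))
# tensor([[1.], [2.], [3.]])

Each submodel still does at most one forward pass per batch, and the masked assignment scatters the results back into their original positions.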