Let me try to generalize it a bit for you, to a scalable number of branches.
class MyModel(nn.Module):
    """A model with a configurable number of parallel output branches.

    Each sample in the batch is routed through exactly one of
    ``num_branches`` independent Linear heads, chosen per sample by a
    selector tensor passed to :meth:`forward`.
    """

    def __init__(self, in_features: int = 16, out_features: int = 1, num_branches: int = 5) -> None:
        """Init method.

        Args:
            in_features (int, optional): Number of input features. Defaults to 16.
            out_features (int, optional): Output features. Defaults to 1.
            num_branches (int, optional): Number of branches in model. Defaults to 5.
        """
        super().__init__()
        # One independent Linear head per branch, registered via ModuleList
        # so their parameters are tracked by the module.
        self.output_fcs = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_branches)]
        )
        self.num_branches = num_branches
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        """Forward Method.

        Args:
            X (torch.Tensor): Input of shape (batch, in_features).
            S (torch.Tensor): Input of shape (batch,). Integer values in [0, num_branches - 1].

        Returns:
            torch.Tensor: Model predictions, of shape (batch, out_features).
        """
        # int(s): make the ModuleList index an explicit Python int rather than
        # relying on the 0-dim tensor's __index__; dim= is the canonical
        # torch.cat keyword (axis= is only a NumPy-compatibility alias).
        return torch.cat(
            [self.output_fcs[int(s)](X[i, :]).unsqueeze(0) for i, s in enumerate(S)],
            dim=0,
        )
Let’s do a simple test for this as well:
model = MyModel(16, 1, 5)
batch_size = 4
dummy_input = torch.randn(batch_size, 16)
dummy_selector = torch.randint(0, 5, (batch_size,))
print(dummy_selector)
# tensor([0, 3, 3, 4])
output = model(dummy_input, dummy_selector)
print(output, output.shape)
# tensor([[-0.8775],
# [ 0.0129],
# [-0.3767],
# [-0.4781]], grad_fn=<CatBackward0>) torch.Size([4, 1])
So in the above example, no sample passed through branches 1 and 2.
Note the forward pass is kind of inefficient now: we lose any advantage of batching, since we perform B separate Linear calls for a batch of B samples.
A more efficient alternative would be to calculate which samples are in the same branch (e.g. with torch.where(S==i), given i is a branch number). However, after passing them through a branch together, you would have to restore their original order in the batch. Maybe you could try something like:
def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """Forward pass that batches samples per branch.

    Groups the samples assigned to each branch, runs each group through its
    branch in a single batched Linear call, and scatters the results back
    into the samples' original positions in the batch.

    Args:
        X (torch.Tensor): Input of shape (batch, in_features).
        S (torch.Tensor): Branch selector of shape (batch,). Integer values
            in [0, num_branches - 1].

    Returns:
        torch.Tensor: Model predictions, of shape (batch, out_features).
    """
    # new_zeros allocates on X's device with X's dtype, so the module also
    # works when the input lives on GPU (plain torch.zeros would allocate
    # on CPU and the indexed assignment below would raise a device mismatch).
    output = X.new_zeros(X.shape[0], self.out_features)
    for i in range(self.num_branches):
        idx = torch.where(S == i)[0]
        if idx.numel() > 0:  # skip branches that received no samples
            output[idx] = self.output_fcs[i](X[idx])
    return output
Test it out and see whichever works and suits your use case more : )