Different model branches for different samples in batch

Hello, imagine I have a neural network with multiple output branches. I want to train the upstream layers jointly, while the output layer is plugged in based on a categorical variable. Let’s assume there are simply two FC output layers:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # some layers
        self.output_fc1 = nn.Linear(16, 1)
        self.output_fc2 = nn.Linear(16, 1)

Now, I have a batch of samples x and an associated tensor output_branch_selector. The elements of output_branch_selector tell me which output layer should be plugged in for each sample in the batch; each element is either 0 (for output_fc1) or 1 (for output_fc2).
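For illustration (hypothetical values), with a batch of 4 samples:

output_branch_selector = torch.tensor([0, 1, 1, 0])
# sample 0 -> output_fc1, samples 1 and 2 -> output_fc2, sample 3 -> output_fc1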

Is there a way to do this in a single forward pass, automatically assigning an output branch to each sample based on output_branch_selector? E.g.

model = MyModel()
output = model(x, output_branch_selector)

Let me try to generalize it a bit for you, to a variable number of branches.

class MyModel(nn.Module):
    "A model with variable number of branches."
    def __init__(self, in_features: int = 16, out_features: int = 1, num_branches: int = 5) -> None:
        """Init method.

        Args:
            in_features (int, optional): Number of input features. Defaults to 16.
            out_features (int, optional): Output features. Defaults to 1.
            num_branches (int, optional): Number of branches in model. Defaults to 5.
        """
        super().__init__()
        self.output_fcs = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_branches)]
        )
        self.num_branches = num_branches
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        """Forward Method.

        Args:
            X (torch.Tensor): Input of shape (batch, in_features).
            S (torch.Tensor): Input of shape (batch,). Integer values in [0, num_branches - 1].

        Returns:
            torch.Tensor: Model predictions, of shape (batch, out_features).
        """
        return torch.cat([self.output_fcs[s](X[i, :]).unsqueeze(0) for i, s in enumerate(S)], axis=0)

Let’s do a simple test for this as well:

model = MyModel(16, 1, 5)
batch_size = 4
dummy_input = torch.randn(batch_size, 16)
dummy_selector = torch.randint(0, 5, (batch_size,))

print(dummy_selector)
# tensor([0, 3, 3, 4])

output = model(dummy_input, dummy_selector)
print(output, output.shape)
# tensor([[-0.8775],
#         [ 0.0129],
#         [-0.3767],
#         [-0.4781]], grad_fn=<CatBackward0>) torch.Size([4, 1])

So in the above example, no sample should have passed through branches 1 and 2.
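A quick way to check this (a sketch, reusing the model and output from the test above): backpropagate a dummy loss and see which branch layers received gradients. Branches that no sample passed through should still have weight.grad set to None.

output.sum().backward()
for i, fc in enumerate(model.output_fcs):
    print(i, fc.weight.grad is not None)
# with the selector above, branches 1 and 2 should print False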

Note that this forward pass is rather inefficient: we lose any advantage of batching, since we perform B separate Linear calls for a batch of B samples.

A more efficient alternative is to find which samples belong to the same branch (e.g. with torch.where(S == i), where i is a branch index), pass them through that branch together, and then restore their original order in the batch. Maybe you could try something like:

    def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        output = torch.zeros(X.shape[0], self.out_features, device=X.device, dtype=X.dtype)
        for i in range(self.num_branches):
            s = torch.where(S == i)[0]  # indices of the samples assigned to branch i
            if s.shape[0] > 0:
                output[s] = self.output_fcs[i](X[s])  # indexed assignment restores the original batch order
        return output

Test it out and see whichever works and suits your use case more : )


Thanks a lot, it looks promising and I’ll test it tomorrow. In the forward you index X using the variable i, but I don’t see where i comes from 🙂

You are right:

return torch.cat([self.output_fcs[s](X[i, :]).unsqueeze(0) for s in S], axis=0)

should be

return torch.cat([self.output_fcs[s](X[i, :]).unsqueeze(0) for i, s in enumerate(S)], axis=0)

I’ve edited the original post too. Thanks!


I think this solution works. Can we also imagine a way without a for loop?

Probably. For that, we would have to think along the lines of the second solution I suggested: compute the output of each branch sequentially. Actually, we can try to exploit nn.Sequential for this, since nn.Sequential is really equivalent to iterating over a list of layers.

class LinearBranch(nn.Module):
    def __init__(self, in_features: int, out_features: int, branch_id: int):
        super().__init__()
        self.branch_id = branch_id
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, inputs):
        """inputs: Tuple of (X, S, output)."""
        X, S, output = inputs
        s = torch.where(S == self.branch_id)[0]  # samples assigned to this branch
        if s.shape[0] > 0:
            output[s] = self.linear(X[s])  # write results back in their original batch positions
        return inputs

class MyModel(nn.Module):
    "A model with variable number of branches."
    def __init__(self, in_features: int = 16, out_features: int = 1, num_branches: int = 5) -> None:
        super().__init__()
        self.output_fcs = nn.Sequential(
            *[LinearBranch(in_features, out_features, branch_id) for branch_id in range(num_branches)]
        )
        self.num_branches = num_branches
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        output = torch.zeros(X.shape[0], self.out_features, device=X.device, dtype=X.dtype)
        _, _, output = self.output_fcs((X, S, output))
        return output

We can put a wrapper around nn.Linear to handle the checking of the selector. nn.Sequential then passes X, S and an output tensor through the branches one by one, and at each step some samples of the batch (or none at all) are passed through that branch's linear layer, depending on the selector.

Just to check:

batch_size = 4
num_branches = 5

model = MyModel(16, 1, num_branches)

dummy_input = torch.randn(batch_size, 16)
dummy_selector = torch.randint(0, num_branches, (batch_size,))

print(dummy_selector)
# tensor([2, 1, 2, 4])

output = model(dummy_input, dummy_selector)
print(output, output.shape, output.requires_grad)
# tensor([[ 0.1223],
#         [-0.0652],
#         [ 0.2669],
#         [-0.2918]], grad_fn=<IndexPutBackward0>) torch.Size([4, 1]) True

Seems like it works, but do let me know how it goes.


Even if we store all the weight matrices together, we would still have to 1) extract/slice the appropriate branch weight, then 2) compute that branch's matrix multiplication separately, as long as the distribution of samples over branches is not uniform (i.e. for a given batch, a branch can have 0 samples and be inactive, and two branches a and b can have different numbers of assigned samples).
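To illustrate, a minimal sketch (assuming stacked weights W of shape (num_branches, in_features, out_features), inputs X, a selector S and a preallocated output): the non-uniform case still forces per-branch slicing and separate matmuls.

for i in range(num_branches):
    idx = torch.where(S == i)[0]      # samples assigned to branch i
    if idx.numel() > 0:               # the branch may be inactive for this batch
        output[idx] = X[idx] @ W[i]   # 1) slice the branch weight, 2) separate matmul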

It would also not be possible with just a wrapper around nn.Linear: to store a 3D weight matrix, we would need to implement a custom layer with its own nn.Parameter() weight and bias, a custom reset_parameters() method, etc., which is a lot of extra work.
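For reference, a rough skeleton of what such a layer might look like (hypothetical; the init only roughly mirrors nn.Linear, and this is a sketch rather than a drop-in replacement):

import math
import torch
import torch.nn as nn

class BranchedLinear(nn.Module):
    """Sketch: all branch weights stored as a single 3D parameter."""
    def __init__(self, num_branches: int, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_branches, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(num_branches, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # rough analogue of nn.Linear's init
        nn.init.zeros_(self.bias)

    def forward(self, X: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        W = self.weight[S]  # gather per-sample weights: (batch, in_features, out_features)
        b = self.bias[S]    # gather per-sample biases: (batch, out_features)
        return torch.einsum("bi,bio->bo", X, W) + b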

And lastly, torch constructs its computation graph dynamically on every pass anyway, so there is no extra cost when the branching pattern changes between batches.

In the special case where the input has a uniform number of samples per branch (say, in a batch of 10 inputs it is guaranteed that there are 2 samples for each of the 5 branches), it may be possible to compute the linear operation very efficiently as a single operation. Something like this:

import torch

num_branch, samples_per_branch = 5, 2
in_features, out_features = 16, 1

X = torch.randn(num_branch, samples_per_branch, in_features)
# X -> (5, 2, 16)
W = torch.randn(num_branch, in_features, out_features)
# W -> (5, 16, 1)

# loop doing 5 matmuls
out1 = torch.stack([X[i] @ W[i] for i in range(num_branch)])

# single einsum reduction
out2 = torch.einsum("ijk,ikl->ijl", X, W)

print(out1.shape, out2.shape, torch.allclose(out1, out2))
# torch.Size([5, 2, 1]) torch.Size([5, 2, 1]) True

But this scenario is very different from the main idea of this question, and thus not very relevant.


Thanks again @ID56 for the inspiration. It is actually possible to implement this efficient idea for arbitrary batch compositions using einsum:

import torch

num_branch = 5
batch_size = 8
in_features, out_features = 16, 1

X = torch.randn(batch_size, in_features)
W = torch.randn(num_branch, in_features, out_features)

# the vector containing the branch index for each sample
X_branch_idx = torch.randint(0, num_branch, (batch_size,))

# one may stack the W slices corresponding to each sample's branch
Wstack = torch.stack([W[i, :, :] for i in X_branch_idx])
# Wstack -> (8, 16, 1)

# and compute the output for the whole batch using a single einsum
out = torch.einsum("ij,ijk->ik", X, Wstack)

There’s still a for loop to build Wstack, but that should add very little overhead.
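As a side note, that loop could likely be replaced by advanced indexing, assuming X_branch_idx is an integer tensor:

Wstack = W[X_branch_idx]  # gathers the per-sample weight slices without a Python loop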


Glad I could be useful! Note that the memory for Wstack grows linearly with the batch size: for example, with batch size 512 you need to allocate a tensor of size (512, 16, 1) during each forward pass, on top of the weights themselves. Regardless, it is still a nice solution for when the batch (and the model weights) isn’t too large.
