How to choose specific parameters for specific data

Suppose I have a set of data to classify, where each sample is a triple (x, topic, y). x is the input to my model and y is the label. topic is some side information about x, and it takes one of 3 values. I want samples from different topics to share some of the parameters while the remaining parameters are topic-specific. For example:

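import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # num_classes is unused in this simplified example
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()  # must be called before registering submodules
        self.W = nn.Linear(input_dim, hidden_dim)    # shared by all topics
        self.W1 = nn.Linear(hidden_dim, hidden_dim)  # topic 0 only
        self.W2 = nn.Linear(hidden_dim, hidden_dim)  # topic 1 only
        self.W3 = nn.Linear(hidden_dim, hidden_dim)  # topic 2 only

    def forward(self, x, topic):
        x = F.relu(self.W(x))
        # topic is treated as a single value, so this handles one sample at a time
        if topic == 0:
            x = self.W1(x)
        elif topic == 1:
            x = self.W2(x)
        elif topic == 2:
            x = self.W3(x)
        return x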

Since we normally pack several samples into a batch, the code above only works when batch_size == 1, and I find it hard to extend to a larger batch_size. I have found two ways to solve this problem. One is to loop inside forward, which seems inefficient. The other is to build a huge weight tensor by selecting one matrix per sample according to the topic tensor and then do a batched matrix multiplication in forward, which is also inefficient. I am looking for someone to help me solve this problem efficiently. I would appreciate it a lot!
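
To make the two workarounds concrete, here is a minimal sketch of what I mean. The names TopicNet, shared, topic_weight, topic_bias, forward_loop and forward_bmm are made up for illustration, and I am assuming topic is a LongTensor of shape (batch,) with values in 0..2. The per-topic layers are stored as one stacked parameter instead of three nn.Linear modules so that both variants can share the same weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopicNet(nn.Module):
    """Shared first layer plus one topic-specific layer per topic (hypothetical sketch)."""

    def __init__(self, input_dim, hidden_dim, num_topics=3):
        super().__init__()
        self.shared = nn.Linear(input_dim, hidden_dim)
        # One (hidden_dim x hidden_dim) weight matrix and one bias vector per topic.
        self.topic_weight = nn.Parameter(
            torch.randn(num_topics, hidden_dim, hidden_dim) * 0.1
        )
        self.topic_bias = nn.Parameter(torch.zeros(num_topics, hidden_dim))

    def forward_loop(self, x, topic):
        # Workaround 1: loop over the topics and fill in the rows
        # that belong to each topic using a boolean mask.
        h = F.relu(self.shared(x))
        out = torch.zeros_like(h)
        for t in range(self.topic_weight.shape[0]):
            mask = topic == t
            out[mask] = h[mask] @ self.topic_weight[t] + self.topic_bias[t]
        return out

    def forward_bmm(self, x, topic):
        # Workaround 2: pick one weight matrix per sample, then batch-multiply.
        h = F.relu(self.shared(x))
        w = self.topic_weight[topic]  # (batch, hidden_dim, hidden_dim)
        b = self.topic_bias[topic]    # (batch, hidden_dim)
        return torch.bmm(h.unsqueeze(1), w).squeeze(1) + b


net = TopicNet(input_dim=4, hidden_dim=8)
x = torch.randn(6, 4)
topic = torch.tensor([0, 1, 2, 0, 2, 1])
out_loop = net.forward_loop(x, topic)
out_bmm = net.forward_bmm(x, topic)
```

Both variants should give the same output for the same batch. The first needs one pass per topic, and the second materializes a full weight matrix for every sample in the batch, which is the memory cost I am worried about.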