Suppose I have a set of data to classify, where each example is represented as (x, topic, y). x is the input to my model and y is the label. topic is some auxiliary information about x, and there are 3 topic classes. I want examples from different topics to share some of the parameters while the remaining parameters are topic-specific. For example:

```
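import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()  # required before registering submodules
        # num_classes is not used in this snippet
        self.W = nn.Linear(input_dim, hidden_dim)    # shared by all topics
        self.W1 = nn.Linear(hidden_dim, hidden_dim)  # topic 0 only
        self.W2 = nn.Linear(hidden_dim, hidden_dim)  # topic 1 only
        self.W3 = nn.Linear(hidden_dim, hidden_dim)  # topic 2 only

    def forward(self, x, topic):
        # topic is a single integer here, so this handles one example at a time
        x = F.relu(self.W(x))
        if topic == 0:
            x = self.W1(x)
        elif topic == 1:
            x = self.W2(x)
        elif topic == 2:
            x = self.W3(x)
        return x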
```

Since we pack a group of examples into a batch, the code above only works when batch_size == 1, and I find it hard to extend to a larger batch_size. I found two ways to work around this. One is to use a loop in `forward`,

which seems inefficient; the other is to build a huge parameter matrix selected according to the topic tensor and do a batch matrix multiplication in `forward`,

which is inefficient as well. I am hoping someone can help me solve this problem efficiently. I would appreciate it a lot!
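
For concreteness, the two workarounds mentioned above can be sketched roughly as follows. This is only a minimal sketch under assumptions: `topic` is taken to be a `LongTensor` of shape `(batch,)`, and the class name `BatchedTopicNet`, the method names `forward_loop` and `forward_bmm`, and the use of `nn.ModuleList` are illustrative choices, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedTopicNet(nn.Module):
    """Hypothetical batched variant of Net (names are illustrative)."""

    def __init__(self, input_dim, hidden_dim, num_topics=3):
        super().__init__()
        self.shared = nn.Linear(input_dim, hidden_dim)  # shared by all topics
        # One topic-specific layer per topic (plays the role of W1, W2, W3).
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_topics)]
        )

    def forward_loop(self, x, topic):
        # Workaround 1: loop over topics and route rows with a boolean mask.
        h = F.relu(self.shared(x))
        out = torch.zeros_like(h)
        for t, head in enumerate(self.heads):
            mask = topic == t
            if mask.any():
                out[mask] = head(h[mask])
        return out

    def forward_bmm(self, x, topic):
        # Workaround 2: gather one weight matrix per example, then bmm.
        h = F.relu(self.shared(x))
        W = torch.stack([head.weight for head in self.heads])  # (T, H, H)
        b = torch.stack([head.bias for head in self.heads])    # (T, H)
        W_per_example = W[topic]                                # (B, H, H)
        b_per_example = b[topic]                                # (B, H)
        return torch.bmm(W_per_example, h.unsqueeze(2)).squeeze(2) + b_per_example


net = BatchedTopicNet(input_dim=4, hidden_dim=8)
x = torch.randn(6, 4)
topic = torch.tensor([0, 2, 1, 1, 0, 2])
out_loop = net.forward_loop(x, topic)
out_bmm = net.forward_bmm(x, topic)
```

In this sketch the loop runs once per topic rather than once per example, so its cost grows with the number of topics, not the batch size. The `bmm` version avoids the Python loop but materializes a `(batch, hidden, hidden)` weight tensor, which is the memory cost described above.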