Different weights for each sample in nn.Linear


I have a model that needs to apply a different weight matrix to each sample in a batch. For example, if the batch size is 3, in_features is 108 and out_features is 2, then I have a batched input X of shape (3, 108) and a corresponding weight W of shape (3, 2, 108).

Normally, when the weight is shared by every sample in the batch, W has shape (2, 108) and can simply be plugged into nn.functional.linear(X, W).
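For reference, here is a minimal sketch of the shared-weight case, assuming PyTorch and using random tensors in place of real data. `F.linear(X, W)` computes `X @ W.T`, so all three samples are multiplied by the same (2, 108) matrix.

```python
import torch
import torch.nn.functional as F

batch_size, in_features, out_features = 3, 108, 2

X = torch.randn(batch_size, in_features)             # (3, 108) batched input
W_shared = torch.randn(out_features, in_features)    # (2, 108) one weight for every sample

# F.linear computes X @ W_shared.T, so each row of X uses the same weight
Y = F.linear(X, W_shared)                            # (3, 2)
print(Y.shape)
```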

So far I loop over the samples and compute the linear layer one at a time, i.e. nn.functional.linear(X[i], W[i]) for i in range(batch_size). However, this is very inefficient, so I am wondering whether there is a utility function that would help in this case.
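One way to avoid the Python loop, sketched here under the assumption that X is (batch, in_features) and W is (batch, out_features, in_features) as described above, is to express the per-sample product as a single batched operation. Both `torch.einsum` and `torch.bmm` are standard PyTorch functions; the random tensors are placeholders, and the loop is kept only to check that the results agree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, in_features, out_features = 3, 108, 2

X = torch.randn(batch_size, in_features)                 # (3, 108)
W = torch.randn(batch_size, out_features, in_features)   # (3, 2, 108) one weight per sample

# Baseline: per-sample loop, one F.linear call per row
Y_loop = torch.stack([F.linear(X[i], W[i]) for i in range(batch_size)])  # (3, 2)

# Option 1: einsum, summing over the input-feature index i for each batch index b
Y_einsum = torch.einsum("bi,boi->bo", X, W)              # (3, 2)

# Option 2: batched matrix multiply, treating each row of X as a (108, 1) column vector
Y_bmm = torch.bmm(W, X.unsqueeze(-1)).squeeze(-1)        # (3, 2,1) -> (3, 2)

print(torch.allclose(Y_loop, Y_einsum, atol=1e-5), torch.allclose(Y_loop, Y_bmm, atol=1e-5))
```

If each sample also has its own bias of shape (3, 2), it can simply be added to the result of either option.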