How to calculate multiple linear layer in one pass

I’ve been experimenting with a transformer-like architecture where one input vector x is passed through h different linear layers, one per head; call the outputs y_1, y_2, …, y_h. I then want to pass each head’s result through an additional, head-specific linear layer, independently of the others: y_1 -> lin_1 -> y'_1, y_2 -> lin_2 -> y'_2, and so on.

The first step, converting x to y_1, y_2, …, y_h, is simple. I can use a single linear layer with input size hidden_size and output size h*hidden_size, then split the output evenly across the heads.
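A minimal sketch of this first step (the sizes here are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

hidden_size, h = 64, 8  # illustrative sizes

# One linear layer produces all h head outputs at once.
proj = nn.Linear(hidden_size, h * hidden_size)

x = torch.randn(hidden_size)
# Split the flat output evenly: one row per head, y[i] == y_i.
y = proj(x).view(h, hidden_size)
```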

But the second step is hard to implement with GPU parallelization in mind. Currently, I iterate over the heads and apply each linear layer one at a time. The pseudocode is:

# lins, y, y_prime are lists of length h
for head in range(h):
    y_prime[head] = lins[head](y[head])

Is there a way supported by pytorch that can calculate the results for all linear layers simultaneously? Thanks.

If I am understanding the problem correctly, this should be implementable as a single batched matrix multiply.

Assume that y is a tensor of shape (n_head, hidden_size), and that each linear layer takes in hidden_size in-dims and returns output_size out-dims. This means that each y'_i has shape (output_size,).

Instead of using n_head separate linear layers, we can store a single large weight tensor of shape (n_head, output_size, hidden_size) and do a batched matrix multiply between the weight and y:

weight = torch.randn(n_head, output_size, hidden_size, requires_grad=True)
y = torch.randn(n_head, hidden_size)

y_prime = torch.matmul(weight, y.unsqueeze(-1)).squeeze(-1)

This gets a little more involved, but it still extends to the case where you have a batch dimension in addition to the head dimension, and to the case where you care about the bias term of a linear layer.
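As a sketch of that extension, assuming y carries a leading batch dimension of shape (batch, n_head, hidden_size) and that we add a per-head bias (the sizes are illustrative): weight broadcasts over the batch dimension, so the same matmul-with-unsqueeze pattern works unchanged.

```python
import torch

batch, n_head, hidden_size, output_size = 4, 8, 64, 32  # illustrative sizes

weight = torch.randn(n_head, output_size, hidden_size, requires_grad=True)
bias = torch.randn(n_head, output_size, requires_grad=True)
y = torch.randn(batch, n_head, hidden_size)

# (n_head, out, hidden) broadcasts against (batch, n_head, hidden, 1);
# matmul gives (batch, n_head, out, 1), then squeeze the trailing 1.
# bias of shape (n_head, out) broadcasts over the batch dimension.
y_prime = torch.matmul(weight, y.unsqueeze(-1)).squeeze(-1) + bias
```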

If you’re adventurous and on a nightly build of PyTorch, torch.vmap can do this without the dimension squeezing business:

from torch._vmap_internals import vmap
weight = torch.randn(n_head, output_size, hidden_size, requires_grad=True)
y = torch.randn(n_head, hidden_size)
y_prime = vmap(torch.matmul)(weight, y)

Thanks a lot for this elegant solution. I have one more question about batched matmul: if instead of (n_head, hidden), each y_i has shape (batch, windows, n_head, hidden), do I need to manually unsqueeze the first two dimensions of the weight tensor and repeat them to match the size of y_i?

torch.matmul does broadcasting for batch dimensions (it does batched matrix-matrix multiply), so no, you don’t need to manually unsqueeze the first two dimensions of the weight tensor. However, we do have to be careful to make sure that the n_head dimension lines up correctly. So if y has shape (batch, windows, n_head, hidden) and weight has shape (n_head, output, hidden), then we would want to line up the n_head dimensions by:

  • unsqueezing y to have shape (batch, windows, n_head, hidden, 1)
  • doing matmul(weight, unsqueezed_y) to get an output of shape (batch, windows, n_head, output, 1)
  • squeezing away the last dimension of size 1.