I want to implement an operation like this:

in which each square represents a matrix, and a line between squares means matrix multiplication.

Assume I have 4 parameter matrices, and each of my inputs has 4 matrices (or vectors); each parameter matrix is multiplied with its corresponding input matrix, repeated over n batches.
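In other words, the operation is equivalent to this naive double loop (just a sketch; the sizes `10` for the matrices and the column-vector inputs are assumptions for illustration):

```
import torch

batch_size = 8
# 4 parameter matrices, each 10x10 (sizes assumed for illustration)
W = [torch.randn(10, 10) for _ in range(4)]
# each sample in the batch carries 4 column vectors of size 10
x = torch.randn(batch_size, 4, 10, 1)

out = torch.empty(batch_size, 4, 10, 1)
for b in range(batch_size):
    for i in range(4):
        # multiply the i-th parameter matrix with the i-th input vector
        out[b, i] = W[i] @ x[b, i]
```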

I implemented this operation via `bmm`, but it is too inefficient. I simply stack my parameters so they are batched like my inputs, like this:

```
import torch
import torch.nn as nn

W = [nn.Linear(10, 10, bias=False) for _ in range(4)]
# repeat the 4 weight matrices batch_size times along a new batch dimension
batched_W = [W[i % 4].weight.unsqueeze(0) for i in range(4 * batch_size)]
batched_W = torch.cat(batched_W, dim=0).cuda()
out = torch.bmm(batched_W, x)  # x is my input, shape (4 * batch_size, 10, k)
```
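Spelled out with concrete shapes on the CPU (the `batch_size` and input shape are assumptions), the stacking above interleaves the weights as W0, W1, W2, W3, W0, W1, ... along the batch dimension, so the inputs must be laid out in the same order:

```
import torch
import torch.nn as nn

batch_size = 3
W = [nn.Linear(10, 10, bias=False) for _ in range(4)]

# same stacking as above: W0, W1, W2, W3, W0, W1, ... along dim 0
batched_W = [W[i % 4].weight.unsqueeze(0) for i in range(4 * batch_size)]
batched_W = torch.cat(batched_W, dim=0)

x = torch.randn(4 * batch_size, 10, 1)  # inputs laid out in the same order
out = torch.bmm(batched_W, x)
```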

Since I use the GPU for training, and accessing `.weight` seems to put the parameters back on the CPU, I must add `.cuda()`. I think this is too inefficient, but I can't come up with another idea.