Something like batched matrix multiplication (bmm)

I want to implement an operation like this:


in which each square represents a matrix, and a line between squares means matrix multiplication.

Assume I have 4 parameter matrices, and each of my inputs contains 4 matrices (or vectors). Each parameter matrix is multiplied with the corresponding input matrix, and this is repeated for n batches.
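To make the intended operation concrete, here is a minimal loop-based sketch of the semantics I have in mind (the 10x10 weight shape and the (batch_size, 4, 10, 1) input shape are just placeholders for illustration):

import torch

batch_size = 8  # placeholder batch size for the example
W = [torch.randn(10, 10) for _ in range(4)]    # the 4 parameter matrices
x = torch.randn(batch_size, 4, 10, 1)          # each input carries 4 vectors (as 10x1 matrices)

# reference semantics: the i-th parameter matrix multiplies the i-th piece of
# every input in the batch
out = torch.stack(
    [torch.stack([W[i] @ x[b, i] for i in range(4)]) for b in range(batch_size)]
)
print(out.shape)  # torch.Size([8, 4, 10, 1])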

I implemented this operation via bmm, but it is too inefficient: I just repeat and stack my parameters so they are batched like my inputs, like this:

import torch
import torch.nn as nn

W = [nn.Linear(10, 10, bias=False) for _ in range(4)]
# repeat the 4 weight matrices so the batch dimension matches x: (4*batch_size, 10, 10)
batched_W = [W[i % 4].weight.unsqueeze(0) for i in range(4 * batch_size)]
batched_W = torch.cat(batched_W, dim=0).cuda()
out = torch.bmm(batched_W, x)  # x is my input, shape (4*batch_size, 10, k)

Since I use a GPU for training, and accessing .weight seems to put the parameters back on the CPU, I have to add the .cuda() call. I think this is too inefficient, but I can't come up with another idea.
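For reference, this is a small check (assuming a CUDA device is available) showing which device the weights actually live on before and after moving a layer:

import torch.nn as nn

W = [nn.Linear(10, 10, bias=False) for _ in range(4)]
# a Linear created like this starts on the CPU; moving the module with .cuda()
# moves its parameters as well
print(W[0].weight.device)         # cpu
print(W[0].cuda().weight.device)  # cuda:0 (requires a CUDA device)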