I want to implement an operation like this one, in which each square represents a matrix and a line between two squares denotes matrix multiplication.
Assume I have 4 parameter matrices and each of my inputs consists of 4 matrices (or vectors). Each of the four parameter matrices is multiplied with the corresponding input matrix, for every input in a batch of n.
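A slow but explicit reference version helps pin down what the operation computes. This is a sketch under assumed layouts: the 4 parameter matrices stacked into one tensor of shape `(4, d_out, d_in)`, and the inputs as `(n, 4, d_in, k)` with `k = 1` for vectors. `reference_op` is a hypothetical helper name, not part of the original code.

```python
import torch

def reference_op(weights: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Multiply the i-th parameter matrix with the i-th input matrix of every sample.

    weights: (4, d_out, d_in)  -- the 4 parameter matrices (assumed layout)
    inputs:  (n, 4, d_in, k)   -- n samples, each holding 4 matrices (k=1 for vectors)
    returns: (n, 4, d_out, k)
    """
    n, num_mats = inputs.shape[0], inputs.shape[1]
    out = []
    for b in range(n):
        # Pair parameter i with input matrix i of sample b.
        out.append(torch.stack([weights[i] @ inputs[b, i] for i in range(num_mats)]))
    return torch.stack(out)

weights = torch.randn(4, 10, 10)
inputs = torch.randn(3, 4, 10, 1)
print(reference_op(weights, inputs).shape)  # torch.Size([3, 4, 10, 1])
```

Any faster implementation should match this loop numerically.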
I currently implement this with bmm, but it's too inefficient: I just stack copies of my parameters so that they are batched like my inputs, like this:
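import torch
import torch.nn as nn
W = [nn.Linear(10, 10, bias=False) for _ in range(4)]  # 4 parameter matrices, each 10x10
# Repeat the 4 weights batch_size times; slice i uses W[i % 4]
batched_W = [W[i % 4].weight.unsqueeze(0) for i in range(4 * batch_size)]
batched_W = torch.cat(batched_W, dim=0).cuda()  # (4*batch_size, 10, 10), copied to the GPU every step
out = torch.bmm(batched_W, x)  # x is my input: (4*batch_size, 10, k) on the GPU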
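One way to avoid building the repeated weight tensor is to let `torch.matmul` broadcast over the batch. A minimal sketch, assuming the 4 weights are stored as a single `(4, d, d)` parameter `W` (instead of four `nn.Linear` layers) and that `x` follows the `(4*batch_size, d, k)` layout the `bmm` call requires, where slice `i` pairs with parameter `i % 4`. The sizes are made up for illustration.

```python
import torch
import torch.nn as nn

batch_size, d, k = 3, 10, 1
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumption: the 4 weights live in one (4, d, d) parameter on the training device.
W = nn.Parameter(torch.randn(4, d, d, device=device))

# Input in the same layout as the original bmm call: slice i belongs to W[i % 4].
x = torch.randn(4 * batch_size, d, k, device=device)
x4 = x.view(batch_size, 4, d, k)  # sample b, matrix j  <->  slice b*4 + j

# torch.matmul broadcasts W (4, d, d) against x4 (batch_size, 4, d, k).
out = torch.matmul(W, x4).reshape(4 * batch_size, d, k)

# The same contraction written with einsum: out[b,i,j,k] = sum_l W[i,j,l] * x4[b,i,l,k].
out_einsum = torch.einsum("ijl,bilk->bijk", W, x4).reshape(4 * batch_size, d, k)

# Reference: the original approach, which repeats W batch_size times before bmm.
out_bmm = torch.bmm(W.repeat(batch_size, 1, 1), x)
print(torch.allclose(out, out_bmm, atol=1e-5))  # True
```

This replaces the Python list comprehension, the `torch.cat` over `4*batch_size` slices, and the per-step CPU-to-GPU copy with a single call, and gradients still flow to `W` as usual.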
Since I train on the GPU and the .weight tensors turn out to be on the CPU, I have to add .cuda(). I think this is too inefficient, but I can't come up with another idea.
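Accessing `.weight` does not move anything to the CPU. A more likely cause, if `W` is kept as a plain Python list on the model, is that `nn.Module` does not register modules stored in a plain list, so `model.to(device)` leaves them on the CPU and `model.parameters()` (and hence the optimizer) never sees them. Wrapping the list in `nn.ModuleList`, or holding the weights in one `nn.Parameter`, fixes both. The sketch below illustrates this; `PlainListHolder` and `FourMatrixOp` are hypothetical names, and the scaled-normal initialization is an assumption.

```python
import torch
import torch.nn as nn

class PlainListHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does NOT register these layers.
        self.W = [nn.Linear(10, 10, bias=False) for _ in range(4)]

class FourMatrixOp(nn.Module):
    """Hypothetical module holding the 4 parameter matrices as one registered tensor."""

    def __init__(self, num_mats: int = 4, d_in: int = 10, d_out: int = 10):
        super().__init__()
        # Registered parameter: .to(device) moves it and .parameters() includes it.
        # Scaled-normal init is an assumption (nn.Linear's default init differs).
        self.weight = nn.Parameter(torch.randn(num_mats, d_out, d_in) * d_in ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n, num_mats, d_in, k) -> (n, num_mats, d_out, k); weight broadcasts over n.
        return torch.matmul(self.weight, x)

device = "cuda" if torch.cuda.is_available() else "cpu"

plain = PlainListHolder().to(device)
print(len(list(plain.parameters())))  # 0: the list's layers are invisible to the module
print(plain.W[0].weight.device.type)  # cpu: .to(device) did not move them

op = FourMatrixOp().to(device)
print(len(list(op.parameters())))  # 1
x = torch.randn(3, 4, 10, 1, device=device)
print(op(x).shape)  # torch.Size([3, 4, 10, 1])
```

With `op` built this way, `torch.optim.SGD(op.parameters(), lr=...)` updates the weights directly on the GPU and no `.cuda()` call is needed inside the forward pass.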