Tensor Multiplication Question

I’m not sure if it’s possible, but I’m trying to multiply a 2D tensor by a 3D tensor and get out another 2D tensor.

Code: r = input.matmul(newW)
where input has dimensions [128, 512] and newW has dimensions [512, 10, 128], and I am trying to get out a [128, 10] tensor.

The idea is to do [x, y] * [y, z, x] = [x, z]. For each of the x rows, the goal is to do [1, y].matmul([y, z]), i.e. multiply the i-th [1, y] row of input by the i-th [y, z] slice of newW, and then concatenate the x resulting [1, z] tensors.
I accomplished this with a for-loop, but I’m hoping to get rid of the loop to speed up computation.
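The loop itself isn’t shown here, so the following is only a minimal sketch of the row-by-row approach described above. It assumes the [y, z] matrix paired with row i of input is newW[:, :, i]; the variable names inp, rows and r_loop are illustrative.

```python
import torch

# Sizes from the question: x = 128 rows, y = 512 features, z = 10 outputs.
x, y, z = 128, 512, 10
inp = torch.randn(x, y)        # [x, y]
newW = torch.randn(y, z, x)    # [y, z, x]: one [y, z] matrix per row of inp

rows = []
for i in range(x):
    # [1, y] @ [y, z] -> [1, z]
    rows.append(inp[i:i + 1].matmul(newW[:, :, i]))

# Concatenate the x results of shape [1, z] along dim 0 -> [x, z]
r_loop = torch.cat(rows, dim=0)
print(r_loop.shape)
```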

However, my 2D*3D matmul line gives me this error: RuntimeError: invalid argument 6: wrong matrix size at c:\a\w\1\s\tmp_conda_3.6_105809\conda\conda-bld\pytorch_1544094150554\work\aten\src\thc\generic/THCTensorMathBlas.cu:492
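A likely explanation for the error: torch.matmul treats a 3D argument as a batch of matrices indexed by its first dimension. So newW is read as 512 matrices of shape [10, 128], input is broadcast against that batch, and each product would be [128, 512] @ [10, 128], whose inner sizes (512 vs 10) don’t match. The sketch below just reproduces the failure with random tensors; the exact error text depends on the PyTorch version.

```python
import torch

inp = torch.randn(128, 512)
newW = torch.randn(512, 10, 128)

# matmul sees newW as a batch of 512 matrices of shape [10, 128];
# [128, 512] @ [10, 128] has mismatched inner dimensions, so this raises.
try:
    inp.matmul(newW)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```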

My linear algebra knowledge only goes up to 2D * 2D. Is 2D * 3D tensor multiplication even possible?
If so, I would greatly appreciate an explanation of how to get this to work in PyTorch.
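One way to express [x, y] * [y, z, x] = [x, z] without a loop is as a contraction over y with x kept aligned between the two operands. A minimal sketch with torch.einsum and, equivalently, torch.bmm, assuming (as in the question) that newW[:, :, i] is the [y, z] matrix for row i:

```python
import torch

x, y, z = 128, 512, 10
inp = torch.randn(x, y)        # [x, y]
newW = torch.randn(y, z, x)    # [y, z, x]

# einsum: sum over y; x appears in both inputs and the output, so rows stay paired.
r_einsum = torch.einsum('xy,yzx->xz', inp, newW)

# bmm: move x to the front so it becomes the batch dimension.
#   inp.unsqueeze(1):       [x, 1, y]
#   newW.permute(2, 0, 1):  [x, y, z]
#   bmm result:             [x, 1, z] -> squeeze(1) -> [x, z]
r_bmm = torch.bmm(inp.unsqueeze(1), newW.permute(2, 0, 1)).squeeze(1)

print(r_einsum.shape, torch.allclose(r_einsum, r_bmm, rtol=1e-4, atol=1e-4))
```

Both forms compute the same x separate [1, y] @ [y, z] products as the loop, just in a single call; einsum reads closest to the shape notation, while bmm makes the batch dimension explicit.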

Thanks!

Update:
I believe I resolved the issue by expanding the 2D tensor to 3D, then doing pairwise (element-wise) multiplication, and finally summing across a dimension (akin to how traditional matrix multiplication works).
I’m not entirely confident this accomplishes what I intended, but I believe it does.
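A quick way to check the expand-multiply-sum idea: broadcasting a [batch, in] tensor against a per-sample [batch, out, in] weight, multiplying element-wise and summing over the input dimension gives the same numbers as a batched matrix-vector product. The small sizes and the names W, ret_sum and ret_bmm below are arbitrary, chosen only for the comparison.

```python
import torch

batch, n_out, n_in = 4, 3, 5
inp = torch.randn(batch, n_in)
W = torch.randn(batch, n_out, n_in)   # one [n_out, n_in] weight matrix per sample

# Expand-multiply-sum: broadcast inp over the output dimension,
# multiply element-wise, then sum over the input dimension.
ret_sum = (inp.unsqueeze(1) * W).sum(dim=2)           # [batch, n_out]

# The same contraction as a batched matrix-vector product.
ret_bmm = torch.bmm(W, inp.unsqueeze(2)).squeeze(2)   # [batch, n_out, 1] -> [batch, n_out]

print(torch.allclose(ret_sum, ret_bmm, atol=1e-5))
```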
Working Code:

import torch
# Shapes: input is [batch, in], weight is [out, in, 2]
W1 = weight[:, :, 0].expand(input.size(0), -1, -1)  # [out, in] -> [batch, out, in]
W2 = weight[:, :, 1].expand(input.size(0), -1, -1)  # [out, in] -> [batch, out, in]
inp = torch.sigmoid(input.expand(weight.size(0), -1, -1)).permute(1, 0, 2)  # [out, batch, in] -> [batch, out, in]
newW = inp * W1  # element-wise product, [batch, out, in]
newW += (1 - inp) * W2  # element-wise product, accumulated into newW
input = input.expand(weight.size(0), -1, -1).permute(1, 0, 2)  # [out, batch, in] -> [batch, out, in]
ret = torch.sum(input * newW, 2)  # element-wise product, then sum over the input
# dimension to get out [batch, out]
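Since ret[b, o] = sum_i input[b, i] * (s[b, i] * W1[o, i] + (1 - s[b, i]) * W2[o, i]) with s = sigmoid(input), the sum splits into two ordinary 2D matmuls, so the [batch, out, in] tensors never need to be materialized. A sketch under the shape assumptions in the comments above (weight is [out, in, 2]); the function names gated_linear and gated_expand are hypothetical, and the second one simply wraps the working code for comparison.

```python
import torch

def gated_linear(input, weight):
    # Assumes input: [batch, in], weight: [out, in, 2].
    s = torch.sigmoid(input)                     # [batch, in]
    W1, W2 = weight[:, :, 0], weight[:, :, 1]    # each [out, in]
    # (input * s) @ W1^T + (input * (1 - s)) @ W2^T -> [batch, out]
    return (input * s) @ W1.t() + (input * (1 - s)) @ W2.t()

def gated_expand(input, weight):
    # The expand-based working code, wrapped in a function for comparison.
    W1 = weight[:, :, 0].expand(input.size(0), -1, -1)
    W2 = weight[:, :, 1].expand(input.size(0), -1, -1)
    inp = torch.sigmoid(input.expand(weight.size(0), -1, -1)).permute(1, 0, 2)
    newW = inp * W1
    newW += (1 - inp) * W2
    input = input.expand(weight.size(0), -1, -1).permute(1, 0, 2)
    return torch.sum(input * newW, 2)

x = torch.randn(128, 512)
weight = torch.randn(10, 512, 2)
print(torch.allclose(gated_linear(x, weight), gated_expand(x, weight), rtol=1e-4, atol=1e-4))
```

The two-matmul form should use less memory for large batches, since its largest intermediates are [batch, in] rather than [batch, out, in]; any speed difference is worth measuring on the actual sizes.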