Element-wise matrix multiplication along a certain dimension

Hi! I have an input of shape (b, x, y) and a weight matrix of shape (x, y, y), where b is the batch size and x is a dimension I would like to also broadcast across. That is, I’d like to do the equivalent of the following for loop:

def forward(self, x):
    # x.shape = (b, x, y)
    # self.W.shape = (x, y, y)
    out = torch.zeros_like(x)  # shape (b, x, y)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            # (y, y) matrix times (y,) vector -> (y,) vector
            out[i, j] = torch.matmul(self.W[j], x[i, j])
    return out.sum(1)  # returns shape (b, y), summing over x

As you can see, the weight matrix is batch-multiplied along dimension zero of the input and applied element-wise (one matrix per index) along dimension 1. This is somewhat similar to a convolution where the kernel is the size of the entire input. I can't figure out how to implement this without the loops. Can anyone help?
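To pin down exactly what I'm after: I suspect the whole thing collapses into a single `torch.einsum` contraction over both the x and y dimensions, something like the sketch below (the index names `j`, `k`, `l` are just my labels for the x dimension and the two y dimensions), but I'm not confident this is the right or most efficient formulation:

```python
import torch

b, x_dim, y = 2, 3, 4
X = torch.randn(b, x_dim, y)      # the input (b, x, y)
W = torch.randn(x_dim, y, y)      # the weight (x, y, y)

# Reference loop, as in the question
out = torch.zeros(b, x_dim, y)
for i in range(b):
    for j in range(x_dim):
        out[i, j] = torch.matmul(W[j], X[i, j])
loop_result = out.sum(1)          # shape (b, y)

# Candidate one-liner: contract the last y index of W with the
# y index of X, keeping batch b, and sum out the x index j
einsum_result = torch.einsum('jkl,bjl->bk', W, X)

print(torch.allclose(loop_result, einsum_result, atol=1e-5))
```

If that einsum is correct, the intermediate `(b, x, y)` tensor never needs to be materialized at all.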