Element-wise batch multiplication

I have tensors X and Y, where X has size (B, N, N) and Y has size (N, N). I’d like to element-wise multiply every batch of X by Y without replicating Y to size (B, N, N) and without writing a for loop. Any tips?

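Element-wise operations such as torch.mul (or the * operator) support broadcasting, just like torch.add, so Y is applied to every batch of X without being replicated.
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

import torch

a = torch.rand(2, 3, 3)
b = torch.rand(3, 3)
c = torch.mul(a, b)  # same as a * b; b is broadcast over dim 0, c has shape (2, 3, 3)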
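
If you want to convince yourself that broadcasting really does what the question asks, here is a minimal sketch that compares the broadcast product against an explicit per-batch loop. The helper name batch_elementwise_mul and the sizes B = 4, N = 3 are made up for illustration; they are not part of PyTorch.

```python
import torch


def batch_elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply every (N, N) slice of x by y element-wise.

    x has shape (B, N, N) and y has shape (N, N). Broadcasting aligns the
    trailing dimensions and treats y as if it had a leading batch dimension.
    """
    return x * y


torch.manual_seed(0)
B, N = 4, 3  # illustrative sizes, chosen arbitrarily
X = torch.rand(B, N, N)
Y = torch.rand(N, N)

Z = batch_elementwise_mul(X, Y)

# Reference result built with the explicit loop the question wants to avoid.
Z_loop = torch.stack([X[i] * Y for i in range(B)])
same_as_loop = torch.allclose(Z, Z_loop)

# If an explicit (B, N, N) view of Y is ever needed, expand() returns a view
# over the same storage instead of copying the data.
Y_view = Y.expand(B, N, N)
shares_memory = Y_view.data_ptr() == Y.data_ptr()
```

Here same_as_loop confirms the broadcast result matches the loop, and shares_memory shows that Y.expand gives a (B, N, N) view without allocating a replicated copy, should some other API require matching shapes.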