How can I do the following operation efficiently?

Hi, I have a data matrix X, say X = [1,2,3;4,5,6;7,8,9], and a binary index matrix Y of the same shape, say Y = [0,1,0;1,0,1;0,0,0]. How can I compute torch.sum(X*Y) efficiently by taking advantage of the fact that Y is binary?
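
To make the question concrete, here is a minimal sketch of the operation and two equivalent ways to write it. It assumes Y is meant to be the 3x3 matrix [0,1,0;1,0,1;0,0,0] and that both tensors are floating point. The variable names (`baseline`, `masked`, `via_dot`) are only illustrative, and whether either alternative is actually faster than `torch.sum(X*Y)` depends on tensor size, sparsity of Y, and device, so it is worth timing on real data.

```python
import torch

# Example data from the question (Y is an assumed 3x3 reading of the post).
X = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
Y = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 0., 0.]])

# Baseline: elementwise product, then sum over all entries.
baseline = torch.sum(X * Y)

# Alternative 1: treat Y as a boolean mask and sum only the selected entries.
mask = Y.bool()
masked = X[mask].sum()

# Alternative 2: a single dot product over the flattened tensors,
# which avoids materializing the intermediate X * Y tensor.
via_dot = torch.dot(X.flatten(), Y.flatten())

print(baseline.item(), masked.item(), via_dot.item())  # 12.0 12.0 12.0
```

If Y is reused across many different X, the boolean mask only needs to be computed once and can be kept around.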