How to mask a 3D tensor using a 2D mask

Suppose I have a 3D tensor A of size num_class * batch_size * Dim, and a mask B of size num_class * batch_size. I want to mask A using B, but I don't know how to do it. Can anyone help me?
Thanks.


What do you mean by masking exactly?

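Do you want to repeat B along the last dimension (Dim) of A?
For that, mask = B.unsqueeze(-1).expand_as(A) would do it. unsqueeze(-1) adds a trailing size-1 dimension to B, and expand_as stretches it to A's full shape without copying data. Store the result under a new name rather than A, so you can still apply it to A afterwards (e.g. A * mask).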
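Here is a minimal sketch of the expand-and-apply idea. It assumes "masking" means one of three things: zeroing out entries where B is False, filling them with a chosen value, or selecting only the rows where B is True. The sizes, the boolean dtype for B, and the fill value 0 are illustrative assumptions, not part of the original question.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
num_class, batch_size, dim = 2, 3, 4

A = torch.randn(num_class, batch_size, dim)            # 3D tensor to mask
B = torch.tensor([[True, False, True],
                  [False, True, True]])                 # 2D mask, shape (num_class, batch_size)

# Add a trailing size-1 axis, then expand it across Dim.
# expand_as returns a view, so no data is copied.
mask = B.unsqueeze(-1).expand_as(A)                     # shape (num_class, batch_size, dim)

# Option 1: zero out the masked-off entries by multiplication
# (bool is promoted to A's float dtype).
A_mul = A * mask

# Option 2: fill the masked-off entries with a chosen value (assumed 0 here).
A_fill = A.masked_fill(~mask, 0.0)

# Option 3: keep only the (class, batch) rows where B is True.
# Indexing with the 2D mask gives shape (number_of_True_entries, dim).
A_sel = A[B]

print(mask.shape, A_mul.shape, A_sel.shape)
```

Broadcasting means the explicit expand_as is optional. A * B.unsqueeze(-1) and A.masked_fill(~B.unsqueeze(-1), 0.0) produce the same results, because the size-1 trailing axis is broadcast against Dim automatically. expand_as is still useful when a function requires the mask to have exactly A's shape.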

Thanks for your reply. That is just what I needed.