Here is the code snippet:
import torch
data = torch.randn((10, 2))
data_norm = torch.linalg.vector_norm(data, dim=1)
mask1 = (data_norm <= 1).unsqueeze(dim=1) # mask1.shape=(10,1)
small_data = data[mask1]
Running it raises the following error:
Traceback (most recent call last):
File "practice_14.py", line 8, in <module>
small_data = data[mask1]
IndexError: The shape of the mask [10, 1] at index 1 does not match the shape of the indexed tensor [10, 2] at index 1
But I remember that mask selection supports broadcasting.
Can someone explain this?
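This is a minimal sketch of the distinction involved, assuming a recent PyTorch release. The PyTorch documentation describes broadcasting for `torch.masked_select`, where the mask and input only need to be broadcastable. Indexing with a boolean tensor (`data[mask]`) instead requires the mask's shape to match the leading dimensions of the indexed tensor. The small (3, 2) tensor below is a made-up stand-in for the random data, chosen so the result is reproducible.

```python
import torch

# Hypothetical deterministic stand-in for torch.randn((10, 2)):
# the row norms are about 0.22, 5.0 and 0.71.
data = torch.tensor([[0.1, 0.2],
                     [3.0, 4.0],
                     [0.5, 0.5]])
row_mask = torch.linalg.vector_norm(data, dim=1) <= 1  # shape (3,)
col_mask = row_mask.unsqueeze(dim=1)                    # shape (3, 1)

# torch.masked_select broadcasts the (3, 1) mask to (3, 2) and
# returns the selected elements as a flattened 1-D tensor.
flat = torch.masked_select(data, col_mask)
print(flat.shape)

# Boolean indexing does not broadcast: the mask's shape must equal the
# leading dimensions of the indexed tensor, so (3, 1) vs (3, 2) fails.
try:
    data[col_mask]
    indexing_failed = False
except IndexError as err:
    indexing_failed = True
    print("IndexError:", err)
```

With the (10, 1) mask from the question, `torch.masked_select(data, mask1)` (or `data.masked_select(mask1)`) would run, while `data[mask1]` raises the `IndexError` shown in the traceback.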
(By the way, if I change mask1 = (data_norm <= 1).unsqueeze(dim=1)
to mask1 = (data_norm <= 1), i.e. mask1.shape=(10,),
the code above runs correctly.)
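The row-mask variant works because a 1-D mask of shape (10,) lines up with dimension 0 of `data`, so each `True` entry selects a whole row. The following is a minimal sketch of that case and two related options. It uses a made-up (3, 2) tensor in place of the random data and assumes a recent PyTorch release.

```python
import torch

# Hypothetical deterministic stand-in for the question's random data.
data = torch.tensor([[0.1, 0.2],
                     [3.0, 4.0],
                     [0.5, 0.5]])
row_mask = torch.linalg.vector_norm(data, dim=1) <= 1  # shape (3,)
col_mask = row_mask.unsqueeze(dim=1)                    # shape (3, 1)

# A 1-D mask that matches dim 0 selects whole rows and keeps the 2-D layout.
rows = data[row_mask]                        # shape (2, 2)

# Squeezing the (3, 1) mask back to (3,) gives the same row selection.
rows_again = data[col_mask.squeeze(dim=1)]   # shape (2, 2)

# Expanding the mask to the full (3, 2) shape makes the selection element-wise,
# so the result is flattened to 1-D, just like torch.masked_select.
flat = data[col_mask.expand_as(data)]        # shape (4,)
```

Which option fits depends on the desired output. The 1-D row mask keeps the selected points as rows of a 2-D tensor. The expanded mask, like `torch.masked_select`, loses the row structure.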