I’m looking for the most pythonic way to select all entries of a Tensor whose indices fulfill a certain condition. I am aware of torch.where(), but its condition acts on the value of each entry, not on its indices.
For example, I want to select all entries where the indices i and j along the first two dimensions fulfill i // k > j // k for some integer k. What is the best way to achieve this?
Mmmh, that doesn’t really answer my question. I’m aware that I can use masked_select() or where() to select entries based on some mask. And you’re right that for the special case of k = 1 I could use tril() or triu() to get a mask in very few lines of code.
What I’m asking, though, is whether there is a more general way to either select entries directly, or at least construct a boolean mask elegantly, for an arbitrary function of the indices i, j, ... along a tensor’s dimensions.
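To make concrete what I have in mind, here is a rough sketch that builds one broadcastable index tensor per dimension with torch.arange() and passes them to a predicate. The helper name `index_mask` and the example shapes are made up for illustration, and I'm not claiming this is the canonical way to do it.

```python
import torch

def index_mask(shape, predicate):
    """Hypothetical helper: boolean mask from a predicate on the indices.

    `predicate` receives one index tensor per dimension in `shape`, each
    shaped so that they broadcast against one another.
    """
    ndim = len(shape)
    indices = []
    for dim, size in enumerate(shape):
        view = [1] * ndim
        view[dim] = size  # e.g. (size, 1) for dim 0, (1, size) for dim 1
        indices.append(torch.arange(size).view(view))
    # expand() covers predicates that ignore some of the indices
    return predicate(*indices).expand(*shape)

k = 2
x = torch.arange(4 * 4 * 3).reshape(4, 4, 3)

# Condition from the question, applied to the first two dimensions only
mask = index_mask(
    x.shape[:2],
    lambda i, j: torch.div(i, k, rounding_mode="floor")
    > torch.div(j, k, rounding_mode="floor"),
)
selected = x[mask]  # one row per selected (i, j) pair; trailing dims are kept
```

With k = 1 the predicate reduces to i > j, which is the strictly lower triangular mask, so this would cover the tril()/triu() case as well.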