Masking a tensor's indices with a binary mask (e.g. from comparison ops)

Let input be a tensor, dim a dimension to mask, and mask a ByteTensor, such that the following holds:

len(mask.size()) == 1 and input.size(dim) == mask.size(0)

I wrote this simple function for the task:

def masked_index(input, dim, mask):
    assert len(mask.size()) == 1 and input.size(dim) == mask.size(0)
    sizes = input.size()
    # reshape the 1-D mask to (N, 1, ..., 1) so it can be expanded over input
    for i in range(len(sizes) - 1):
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(input)
    # boolean indexing flattens the result, so restore the trailing dimensions
    # (note this hard-codes a 3-D input)
    return input[mask].view(-1, sizes[1], sizes[2])

However, I don't know whether there is a better solution.

The gist is that sometimes we want to select indices along a dimension using a mask (ByteTensor), which usually comes from comparison ops (e.g. torch.eq()), instead of explicit indices (LongTensor).
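
For concreteness, here is a small hypothetical example of the intended behaviour (shapes and values are made up; with the function above this only works for dim equal to 0):

import torch

x = torch.randn(4, 2, 3)                # input, shape 4 x 2 x 3
mask = torch.ByteTensor([1, 0, 1, 0])   # keep slices 0 and 2 along dim 0
out = masked_index(x, 0, mask)          # expected shape: 2 x 2 x 3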

@Soumith_Chintala Any comment is welcome!


THIS FUNCTION HAS A BUG! PLEASE SEE THE COMMENT BELOW.


masked_index does not work properly if dim != 0. One of my colleagues suggested this function:

def masked_index(input, dim, mask):
    assert len(mask.size()) == 1 and input.size(dim) == mask.size(0), \
        '{} != 1 or {} != {}'.format(len(mask.size()), input.size(dim), mask.size(0))
    # turn the binary mask into a LongTensor of the selected indices
    indices = torch.arange(0, mask.size(0))[mask].long()
    return input.index_select(dim, indices)

Using torch.arange, we can get the corresponding indices easily (ref. range vs. arange).
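
A quick sanity check with hypothetical shapes, this time along dim = 1, where the first version failed:

x = torch.randn(2, 4, 3)
mask = torch.ByteTensor([1, 0, 1, 0])
out = masked_index(x, 1, mask)   # selects slices 0 and 2 along dim 1, shape: 2 x 2 x 3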


A few small remarks:

  • you could use mask.dim() rather than len(mask.size()) I think :wink:
  • you could use mask.nonzero() rather than torch.arange(0, mask.size(0))[mask].long(); it selects the same indices and is much more explicit!

All in all, I'm not sure this requires a separate function, since you can simply write:

input.index_select(dim, mask.nonzero())

Hope this helps!


Thanks! torch.nonzero() is super useful!

I used those expressions and encountered some problems.
My mask is a Variable; when I use mask.nonzero(), it errors with Variable object has no attribute 'nonzero'.
When I use mask.data.nonzero(), it shows:
{RuntimeError}invalid argument 3: expecting vector of indices at /opt/conda/conda-bld/pytorch_1502006348621/work/torch/lib/THC/generic/THCTensorIndex.cu:405
So I used input.index_select(dim, mask.data.nonzero().squeeze(1)), but it threw another error:
{RuntimeError}save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition
Here my input is a Variable.

Any idea what to do next? Thanks in advance!

I have no idea. Any comments, guys?

Hmm, here input is a Variable: I'm computing the loss of a network, and I'd like to choose some rows to compute the loss on, based on the mask.

Indeed, this works well with a Tensor, but Variable doesn't have the nonzero method.
To work around that, you need to get the underlying tensor with mask.data, apply nonzero to it, and then convert the result back to a Variable, since index_select expects a Variable as input.

Also, it's worth noting that when you get the error below, it is often because you passed a Tensor where a Variable was expected:

{RuntimeError} save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition

So you could do:

input.index_select(dim, Variable(mask.data.nonzero().squeeze(1)))

Tell me if it's OK! (Though I didn't test it, I guess it should be fine.)

I just did exactly what you said, and it works! :grin:

But another question comes with it: would this conversion, Variable -> Tensor -> Variable, destroy the chain that conducts the gradient of mask back to its creator? Are there specific cases where the mask is part of the network and this operation should propagate gradients to the mask?

Yes, indeed, it destroys the chain. I guess in many contexts this is not a problem.
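
Here is a minimal sketch of why (hypothetical values): wrapping y.data in a new Variable creates a fresh leaf, so backward through it never reaches x.

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3), requires_grad=True)
y = x * 2
z = Variable(y.data)   # new leaf Variable: no grad_fn, gradients stop here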

I think a context where you want to optimize on the mask is more likely to be some kind of RL problem where you have a discrete action space.

Variable.nonzero is not implemented yet, as discussed in the link below. However, if it were, I wonder how the backward would be implemented, since it outputs indices…

See an interesting discussion here too:


You're so insightful!

With advanced indexing, we can use:

input[mask.data.nonzero().squeeze(1), :]
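
A minimal sketch of this route with hypothetical shapes (advanced indexing on a Variable should stay differentiable with respect to input):

import torch
from torch.autograd import Variable

input = Variable(torch.randn(4, 3), requires_grad=True)
mask = torch.ByteTensor([1, 0, 1, 0])   # e.g. from a comparison op
rows = mask.nonzero().squeeze(1)        # 1-D LongTensor of selected row indices
selected = input[rows, :]               # shape: 2 x 3, gradients still flow to input
selected.sum().backward()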