# Using torch.gather properly

I have a tensor of images of size `(3600, 32, 32, 3)` and a multi-hot tensor `[0, 1, 1, 0, …]` of size `(3600, 1)`. I basically want to select the images that correspond to a 1 in the multi-hot tensor. I am trying to understand how to use `torch.gather`:

`tensorA.gather(0, tensorB)`

This gives me errors about the dims, and I can't work out how the tensors need to be reshaped.
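For reference, a minimal sketch (with toy shapes) of why `gather` complains: the index tensor must have the same number of dimensions as the input, so a 1D index has to be expanded first — at which point plain integer indexing is usually simpler:

```python
import torch

x = torch.randn(4, 2, 2, 3)   # stand-in for the (3600, 32, 32, 3) image tensor
idx = torch.tensor([1, 3])    # row indices to select

# gather needs an index tensor with the same ndim as x,
# so the 1D index has to be expanded to (2, 2, 2, 3) first:
g = x.gather(0, idx[:, None, None, None].expand(-1, 2, 2, 3))

# plain integer indexing gives the same result with less ceremony:
assert torch.equal(g, x[idx])
```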

I’m not sure if I understand the use case correctly, but wouldn’t indexing the tensor with a mask work?

```python
import torch

x = torch.randn(3600, 32, 32, 3)
idx = torch.randint(0, 2, (3600,))
print(idx)

out = x[idx.bool()]
print(out.shape)
# torch.Size([1765, 32, 32, 3])
print(idx.sum())
# tensor(1765)
```

The multi hot vector would be learned by the network. I believe doing `torch.gather` would allow for gradient updates. Correct me if I am wrong.

Yes, `gather` would not detach the computation graph, but I don’t understand how the model is learning these indices, as it seems that the multi-hot tensor would already be detached or how is your model returning integer indices without using e.g. `argmax`?
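As a quick illustration (with made-up logits) of why `argmax` breaks the graph:

```python
import torch

logits = torch.randn(5, requires_grad=True)
probs = torch.softmax(logits, dim=0)

idx = probs.argmax()
# argmax returns an integer tensor with no autograd history:
print(idx.dtype)              # torch.int64
print(idx.grad_fn)            # None
print(probs.grad_fn is None)  # False: the float probs are still attached
```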

Oh I see. Sorry for the confusing explanation. Let me rephrase it.

I have a Gumbel-Softmax network that outputs a vector `z` (basically logits). I want to take the non-zero indices in this `z` vector and select the images at the same index locations from an image tensor `A` of 3600 images of shape 32x32x3, i.e. `A` is `(3600, 32, 32, 3)`. But since my `z` is of shape `(1, 3600)`, I have been unable to use `.gather`. The output of gather should have shape `(x, 32, 32, 3)`, where `x` is the number of non-zero elements in `z`.

Later on, I want to optimize the Gumbel network to learn an optimal `z`.

Thanks for the follow up as this would match my understanding.
However, what is unclear is how you are getting the integer indices (zeros and ones) from the logit output (`z`). Usually you would apply e.g. `torch.argmax` on the logits to create the indices and could then use them to index `A`. However, `argmax` would already detach the tensor, so it wouldn't matter if you are using my code snippet with a `mask`, `gather`, or any other operation.
Assuming `z` contains logits (values between `[-Inf, +Inf]`) you won’t be able to use it as an index tensor.

My understanding was that I would do `torch.nonzero(z)` and then use that to do `gather`. To get a one-hot encoding I can use the `hard=True` flag here.

You are right about the `hard=True` flag in `gumbel_softmax`, but transforming the outputs via `nonzero()`, `out.long()`, or `out.bool()` would still detach it since the result would be an integer/bool type tensor, which is why I’m unsure how the selection could work (in a differentiable way).

```python
import torch
import torch.nn.functional as F

logits = torch.randn(20, requires_grad=True)
# Sample hard categorical using the "straight-through" trick:
out = F.gumbel_softmax(logits, tau=1, hard=True)

res = torch.nonzero(out)
print(res.grad_fn)
# None

res = out.long()
print(res.grad_fn)
# None

res = out.bool()
print(res.grad_fn)
# None
```

Could multiplying the output with `A` work somehow?

I apologize, I must have forgotten to mention this. With the `hard=True` flag, I am okay with using that output as my indices to sample my images. I won't need `argmax`, `nonzero`, etc.

Unfortunately, you would need to use integer types:

```python
# out is the float output of gumbel_softmax from above
x = torch.randn(20)
x[out]
# IndexError: tensors used as indices must be long, byte or bool tensors
```

The same applies for gather:

> index (LongTensor) – the indices of elements to gather

And type casting it as such would detach the gradients I assume?

No, floating point dtypes are fine, only integer types would detach it.
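A small sketch (with an arbitrary intermediate tensor) showing which casts keep the autograd history:

```python
import torch

x = torch.randn(5, requires_grad=True)
out = x.sigmoid()            # float, attached to the graph

print(out.double().grad_fn)  # not None: float casts keep the history
print(out.long().grad_fn)    # None: integer cast detaches
print(out.bool().grad_fn)    # None: bool cast detaches as well
```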
I’m searching for some similar approaches in older posts to see if and how others have solved it.

I see. Would a `torch.mv(imagesTensor, out)` work? Or something along those lines. Assuming `imagesTensor.shape = (3600, 32, 32, 3)` and `out.shape = (3600, 1)`.

Yeah, exactly. I was thinking if something like this could work for you:

```python
import torch

N = 10
x = torch.randn(N, 32, 32, 3)

# simulated multi-hot selection vector; in practice this would come
# from the Gumbel-Softmax network
out = torch.randint(0, 2, (N,)).float().requires_grad_(True)

# zero out the unselected images in a differentiable way
res = x * out[:, None, None, None]

# recover the selected row indices (this lookup itself is not differentiable)
idx = res.nonzero()[:, 0].unique()
print(len(idx), out.sum())

res = res[idx]
print(res.shape)
# e.g. torch.Size([5, 32, 32, 3])

res = res.mean()

# backward works since the multiplication kept the graph attached
res.backward()

print(out)
# e.g. tensor([0., 0., 1., 1., 1., 1., 0., 0., 1., 0.], requires_grad=True)
```
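Putting it together, a sketch of the full pipeline, assuming the selection comes from `F.gumbel_softmax` with `hard=True` (note that on a 1D logit vector this yields a one-hot vector, i.e. a single selected image, rather than a multi-hot one):

```python
import torch
import torch.nn.functional as F

N = 10
x = torch.randn(N, 32, 32, 3)

logits = torch.randn(N, requires_grad=True)
# straight-through Gumbel: hard one-hot in the forward pass,
# soft gradients in the backward pass
out = F.gumbel_softmax(logits, tau=1, hard=True)

# weight the images by the (one-hot) selection and reduce
res = (x * out[:, None, None, None]).mean()
res.backward()

# the gradient reaches the logits, so the selection is learnable
print(logits.grad is not None)  # True
```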