I have a tensor of images of size (3600, 32, 32, 3) and a multi-hot tensor [0, 1, 1, 0, …] of size (3600, 1). I basically want to select the images that correspond to a 1 in the multi-hot tensor. I am trying to understand how to use torch.gather for this, but it gives me issues with the dims and I can't figure out how to properly reshape the tensors.
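For context, torch.gather requires the index tensor to have the same number of dimensions as the input, which is likely the source of the dim errors. A minimal sketch reproducing the failure with the shapes from the question:

```python
import torch

x = torch.randn(3600, 32, 32, 3)
z = torch.randint(0, 2, (3600, 1))  # the multi-hot tensor

try:
    torch.gather(x, 0, z)  # index is 2-dim but input is 4-dim
    failed = False
except RuntimeError as e:
    failed = True
    print("gather failed:", e)
```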
I’m not sure if I understand the use case correctly, but wouldn’t indexing the tensor with a mask work?

```python
import torch

x = torch.randn(3600, 32, 32, 3)
idx = torch.randint(0, 2, (3600,))
mask = idx.bool()
out = x[mask]
print(out.shape)
# torch.Size([1765, 32, 32, 3]) in this run; the count depends on the random idx
```
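If integer positions are preferred over a boolean mask, an equivalent sketch using nonzero and index_select:

```python
import torch

x = torch.randn(3600, 32, 32, 3)
idx = torch.randint(0, 2, (3600,))

# convert the multi-hot vector to integer positions, then select those rows
pos = idx.nonzero(as_tuple=True)[0]
out = torch.index_select(x, 0, pos)
print(out.shape)  # one image per 1 in idx
```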
The multi-hot vector would be learned by the network. I believe using torch.gather would allow for gradient updates. Correct me if I am wrong.
gather would not detach the computation graph, but I don’t understand how the model is learning these indices, as it seems that the multi-hot tensor would already be detached. How is your model returning integer indices without using e.g. torch.argmax?
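To illustrate the concern, a minimal sketch showing that the argmax output carries no grad_fn:

```python
import torch

logits = torch.randn(5, requires_grad=True)
idx = torch.argmax(logits)

print(idx.dtype)    # torch.int64
print(idx.grad_fn)  # None: the argmax output is detached from the graph
```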
Oh I see, sorry for the confusing explanation. Let me rephrase it.
I have a Gumbel-Softmax network that outputs a vector z, basically logits. I want to take the non-zero indices in this z vector and select the images at the same index locations from an image tensor A of 3600 images of 32x32x3 shape, i.e. (3600, 32, 32, 3). But since my z is of shape (1, 3600), I have been unable to do .gather. The output of gather should be something of shape (x, 32, 32, 3), where x is the number of non-zero elements from z.
Later on, I want to optimize the Gumbel network to learn an optimal z.
Thanks for the follow-up, as this matches my understanding.
However, what is unclear is how you are getting the integer indices (zeros and ones) from the logit output (z). Usually you would apply e.g. torch.argmax on the logits to create the indices and could then use them for the selection. argmax would already detach the tensor, so it wouldn’t matter if you are using my code snippet with a gather, or any other operation. Since z contains logits (values between [-Inf, +Inf]), you won’t be able to use it directly as an index tensor.
My understanding was that I would do torch.nonzero(z) and then use that to do the gather. To get one-hot encoded outputs I can use the hard=True flag here.
You are right about the hard=True flag in gumbel_softmax, but transforming the outputs via out.bool() would still detach them, since the result would be an integer/bool type tensor, which is why I’m unsure how the selection could work (in a differentiable way).

```python
import torch
import torch.nn.functional as F

logits = torch.randn(20, requires_grad=True)
# Sample hard categorical using "Straight-through" trick:
out = F.gumbel_softmax(logits, tau=1, hard=True)
print(out.grad_fn)
# <AddBackward0 object at 0x7f787c7d51f0>
res = torch.nonzero(out)  # returns integer indices; res.grad_fn is None
```
Could multiplying the output with
A work somehow?
I apologize, I must have forgotten to mention: using the hard=True flag, I am okay with using that output as my indices to sample my images. I won’t need to do argmax, nonzero, etc.
Unfortunately, you would need to use integer types as indices:

```python
import torch

x = torch.randn(20)
x[torch.randn(20)]  # floating point values used as indices
# IndexError: tensors used as indices must be long, byte or bool tensors
```
The same applies for gather, whose docs state:
index (LongTensor) – the indices of elements to gather
And type casting it as such would detach the gradients, I assume?
Type casting doesn’t detach in general: casts between floating point dtypes are fine, but casting to an integer type would detach the tensor.
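A quick sketch to check this behavior:

```python
import torch

x = torch.randn(5, requires_grad=True)

# float-to-float casts are differentiable and keep the graph
print(x.double().grad_fn is not None)  # True

# casts to integer types detach the result
print(x.long().grad_fn)  # None
```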
I’m searching through older posts to see if and how others have solved similar use cases.
I see. Would torch.mv(imagesTensor, out) work, or something along those lines? Assuming imagesTensor.shape = (3600, 32, 32, 3) and out.shape = (3600, 1).
Yeah, exactly. I was thinking if something like this could work for you:

```python
import torch

N = 10
x = torch.randn(N, 32, 32, 3)
out = torch.randint(0, 2, (N,)).float().requires_grad_(True)
print(out)
# tensor([0., 0., 1., 1., 1., 1., 0., 0., 1., 0.], requires_grad=True)

# zero out the unselected images via broadcasting
res = x * out[:, None, None, None]

# indices of the images that were kept
idx = res.nonzero()[:, 0].unique()
print(len(idx), out.sum())
# 5 tensor(5., grad_fn=<SumBackward0>)

res = res[idx]
print(res.shape)
# torch.Size([5, 32, 32, 3])

# perform your operation here
res = res.mean()
res.backward()
print(out.grad)
# tensor([ 0.0000,  0.0000, -0.0044,  0.0005,  0.0032, -0.0044,  0.0000,  0.0000,  0.0046,  0.0000])
```
Thank you so much! This is exactly what I was looking for. I appreciate you helping out.
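For completeness: gumbel_softmax over a single dimension yields a one-hot sample, so getting a truly multi-hot z needs an extra modeling choice. One hypothetical setup (my assumption, not from the thread) is a (drop, keep) logit pair per image; the broadcasting trick above then stays differentiable end to end:

```python
import torch
import torch.nn.functional as F

N = 8
images = torch.randn(N, 32, 32, 3)

# hypothetical: one (drop, keep) logit pair per image, so the hard samples
# form a multi-hot vector rather than a single one-hot vector
logits = torch.randn(N, 2, requires_grad=True)

# straight-through: hard 0/1 values forward, soft gradients backward
keep = F.gumbel_softmax(logits, tau=1, hard=True)[:, 1]  # shape (N,)

res = images * keep[:, None, None, None]
loss = res.mean()
loss.backward()

print(logits.grad is not None)  # True: gradients reach the Gumbel logits
```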