Using torch.gather properly

I have a tensor of images of size (3600, 32, 32, 3) and a multi-hot tensor [0, 1, 1, 0, …] of size (3600, 1). I basically want to select the images that correspond to a 1 in the multi-hot tensor. I am trying to understand how to use torch.gather:

tensorA.gather(0, tensorB)

This gives me errors about the dims, and I can't figure out how to reshape the tensors properly.

I’m not sure if I understand the use case correctly, but wouldn’t indexing the tensor with a mask work?

x = torch.randn(3600, 32, 32, 3)
idx = torch.randint(0, 2, (3600,))
mask = idx.bool()

out = x[mask]
print(out.shape)
# torch.Size([1765, 32, 32, 3])
print(mask.sum())
# tensor(1765)

The multi-hot vector would be learned by the network. I believe using torch.gather would allow for gradient updates. Correct me if I am wrong.
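For the gradient question: gather does keep the gathered source tensor in the autograd graph, even though the index argument itself must be an integer tensor and therefore carries no gradient. A minimal sketch (names and shapes are my own):

```python
import torch

# gather keeps the source tensor in the autograd graph; only the
# integer index tensor is non-differentiable.
src = torch.randn(6, requires_grad=True)
idx = torch.tensor([0, 2, 4])   # LongTensor index, not differentiable
picked = src.gather(0, idx)

picked.sum().backward()
print(picked.requires_grad)     # True
print(src.grad)                 # ones at positions 0, 2, 4
```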

Yes, gather would not detach the computation graph, but I don't understand how the model is learning these indices, as it seems the multi-hot tensor would already be detached. Or how is your model returning integer indices without using e.g. argmax?

Oh I see. Sorry for the confusing explanation. Let me rephrase it.

I have a Gumbel-Softmax network that outputs a vector z of logits. I want to take the non-zero indices in this z vector and select the images at the same index locations from my image tensor A of 3600 images of shape 32x32x3, i.e. A has shape (3600, 32, 32, 3). But since my z has shape (1, 3600), I have been unable to use .gather. The output of gather should have shape (x, 32, 32, 3), where x is the number of non-zero elements in z.

Later on, I want to optimize the Gumbel network to learn an optimal z.

Thanks for the follow-up, as this matches my understanding.
However, what is unclear is how you are getting the integer indices (zeros and ones) from the logit output z. Usually you would apply e.g. torch.argmax on the logits to create the indices and could then use them to index A. However, argmax would already detach the tensor, so it wouldn't matter whether you use my code snippet with a mask, gather, or any other operation.
Assuming z contains logits (values in [-Inf, +Inf]), you won't be able to use it as an index tensor directly.
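To illustrate the argmax point, a small check (made-up shapes): the result of argmax is an integer tensor, so it is detached from the graph and no gradient can flow through the indices.

```python
import torch

# argmax returns an int64 tensor, which cannot carry gradients.
logits = torch.randn(5, requires_grad=True)
idx = torch.argmax(logits)
print(idx.dtype)          # torch.int64
print(idx.requires_grad)  # False: the integer result is detached
```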

My understanding was that I would do torch.nonzero(z) and then use that with gather. To get a one-hot encoding I can use the hard=True flag here.

You are right about the hard=True flag in gumbel_softmax, but transforming the outputs via nonzero(), out.long(), or out.bool() would still detach it since the result would be an integer/bool type tensor, which is why I’m unsure how the selection could work (in a differentiable way).

import torch
import torch.nn.functional as F

logits = torch.randn(20, requires_grad=True)
# Sample a hard categorical using the "straight-through" trick:
out = F.gumbel_softmax(logits, tau=1, hard=True)
print(out.grad_fn)
# <AddBackward0 object at 0x7f787c7d51f0>

res = torch.nonzero(out)
print(res.grad_fn)
# None

res = out.long()
print(res.grad_fn)
# None

res = out.bool()
print(res.grad_fn)
# None

Could multiplying the output with A work somehow?

I apologize, I must have forgotten to mention it. Using the hard=True flag, I am okay with using that output as my indices to sample my images. I won't need argmax, nonzero, etc.

Unfortunately, you would need to use integer types:

x = torch.randn(20)
x[torch.randn(20)]
# IndexError: tensors used as indices must be long, byte or bool tensors

The same applies for gather:

index (LongTensor) – the indices of elements to gather

And type casting it as such would detach the gradients I assume?

Yes, casting to one of those integer dtypes would detach it; casting between floating point dtypes is fine, only integer dtypes detach the tensor.
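A quick check of that dtype rule (variable names are my own): a float-to-float cast stays in the autograd graph, while an integer cast drops out of it.

```python
import torch

# Float casts keep the computation graph; integer casts detach.
x = torch.randn(3, requires_grad=True)
y = x * 2.0
print(y.double().requires_grad)  # True: float cast keeps the graph
print(y.long().requires_grad)    # False: integer cast detaches
```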
I’m searching for some similar approaches in older posts to see if and how others have solved it.

I see. Would something like torch.mul(a, out) work? Or something along those lines. Assuming imagesTensor.shape = (3600, 32, 32, 3) and out.shape = (3600, 1).

Yeah, exactly. I was thinking whether something like this could work for you:

import torch

N = 10
x = torch.randn(N, 32, 32, 3)
out = torch.randint(0, 2, (N,)).float().requires_grad_(True)

res = x * out[:, None, None, None]

idx = res.nonzero()[:, 0].unique()
print(len(idx), out.sum())
# 5 tensor(5., grad_fn=<SumBackward0>)

res = res[idx]
print(res.shape)
# torch.Size([5, 32, 32, 3])

# perform your operation here
res = res.mean()

# backward
res.backward()

print(out)
# tensor([0., 0., 1., 1., 1., 1., 0., 0., 1., 0.], requires_grad=True)
print(out.grad)
# tensor([ 0.0000,  0.0000, -0.0044,  0.0005,  0.0032, -0.0044,  0.0000,  0.0000, 0.0046,  0.0000])
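Wiring this up to learned logits, the same masking idea lets gradients reach the Gumbel-Softmax parameters via the straight-through trick. A minimal end-to-end sketch with hypothetical sizes; the two-class (keep/drop) stacking of the logits is my own assumption about the parameterization:

```python
import torch
import torch.nn.functional as F

N = 10
images = torch.randn(N, 32, 32, 3)
logits = torch.randn(N, requires_grad=True)

# Hypothetical parameterization: one 2-class (keep/drop) Gumbel-Softmax
# sample per image; column 0 acts as the differentiable keep-mask.
z = F.gumbel_softmax(torch.stack([logits, -logits], dim=1), tau=1, hard=True)[:, 0]

# Multiply instead of index, so the mask stays in the graph.
res = images * z[:, None, None, None]
res.mean().backward()
print(logits.grad is not None)  # True
```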

Thank you so much! This is exactly what I was looking for. I appreciate you helping out.
