Change tensor values lower than some value

Hi all,

I’m new and have a simple problem I cannot solve. I have a (2, n) tensor of logits and I would like to do a top-k selection. I’ve been reading this link to implement what I want. The code works when the logits are 1-dimensional, but not when they are (2, n), and I can’t figure out how to modify the code accordingly.

There is a very similar post on this forum using cumsum, but it is not what I want. My minimal code is:

import torch

logits = torch.tensor([[ 0.0333,  0.5174,  0.4797],
                       [-0.0644, -0.5815, -0.7552]])
print(logits)
topk_values, topk_indices = torch.topk(logits, 2)
last = topk_values[:,-1].unsqueeze(dim=-1)
print(last)
print(topk_indices)
logits[logits < topk_values[-1]] = float('-inf')

but I’m getting the following error:

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

Could you please help me to understand what’s going on and how to solve it?

Thank you!

Hi isg!

The problem is that logits, with trailing dimension 3, and topk_values[-1], with
trailing dimension 2, are not broadcastable.

I’m guessing what you want is:

logits[logits < last] = float('-inf')

In this case, unsqueeze() has been used to add a trailing singleton (size = 1) dimension
to last, which is broadcastable to the size-3 trailing dimension of logits.
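
To make the shapes concrete, here is a small sketch using the tensors from the snippet above:

import torch

logits = torch.tensor([[ 0.0333,  0.5174,  0.4797],
                       [-0.0644, -0.5815, -0.7552]])
topk_values, topk_indices = torch.topk(logits, 2)

print(topk_values[-1].shape)            # torch.Size([2]) -- not broadcastable with (2, 3)

last = topk_values[:, -1].unsqueeze(dim=-1)
print(last.shape)                       # torch.Size([2, 1]) -- broadcasts with (2, 3)

logits[logits < last] = float('-inf')
print(logits)                           # values below each row's k-th largest are now -inf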

Best.

K. Frank

Thank you @KFrank!

Indeed, my problem was that I assumed a broadcast that was impossible. You pointed me in the right direction (without giving me the solution outright, which forced me to read more about broadcasting and learn more, thanks again).

In the end, the solution I implemented is a class that does what I wanted:

import torch


class TopK:

    def __init__(self, k: int) -> None:
        self.k = k

    def predict(self, logits: torch.Tensor) -> torch.Tensor:

        # Flatten the leading "batch" dimension, so each row holds the
        # logits for one position: (batch, seq, vocab) -> (batch * seq, vocab)
        logits = logits.view(logits.shape[0] * logits.shape[1], -1)

        # Compute the top-k values for each row
        topk_values, _ = torch.topk(logits, self.k)

        # Take the minimum of the top-k values in each row, kept as a
        # column vector so it broadcasts against logits
        min_values = torch.min(topk_values, dim=-1).values.view(logits.shape[0], 1)

        # Renormalize the probabilities by setting any logit smaller than
        # the k-th largest to "-inf" first, and then computing the softmax
        # row-wise
        logits[logits < min_values] = float('-inf')
        probs = logits.softmax(dim=-1)

        # Finally, use the multinomial PyTorch method to make a random pick
        # across the top-k using the renormalized probabilities
        results = torch.multinomial(probs, num_samples=1).flatten()

        return results
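
A quick usage sketch, in case it helps anyone who finds this later (the shapes below are my own illustrative assumptions, not from the original code):

# Hypothetical shapes: a batch of 2 sequences, 4 positions each, vocabulary of 5
logits = torch.randn(2, 4, 5)

sampler = TopK(k=2)
picks = sampler.predict(logits)

print(picks.shape)  # torch.Size([8]): one sampled index per flattened row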