Argmax from multiple tensors without replacement

Current problem: I have a batch of sequences of equal length, and, for each item in the sequence, a probability distribution over all items.
I.e. I have a tensor of shape (batch_size, seq_length, num_items).

For each item in the sequence, I want to compute the argmax using these distributions.
This is trivial with no further restrictions: torch.argmax supports batched computation and a dim argument.
I would, however, like to sample sequences with no duplicates.
A simple heuristic would be: at position j, set the probability of every item already chosen at an earlier position i < j to 0.
Implementing it using for loops:

import torch

batch_samples = []
for i in range(batch_size):
  # seq_dists has shape (seq_length, num_items)
  seq_dists = my_batch[i, :, :]
  seq = []
  for j in range(seq_length):
    # clone so the in-place masking below does not modify my_batch
    seq_dist = seq_dists[j].clone()
    # set prob. of previously chosen items to 0
    seq_dist[torch.tensor(seq, dtype=torch.long)] = 0
    item = torch.argmax(seq_dist).item()
    seq.append(item)
  batch_samples.append(torch.tensor(seq))

# final shape: (batch_size, seq_length)
batch_samples = torch.stack(batch_samples)

However, as expected, this process is incredibly slow.
Is there anything I can do to achieve the same result in a more efficient way?

Any ideas on how this can be solved?

Bump… no way to mitigate this problem?
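
One possible direction, as a minimal sketch rather than a definitive answer: keep the loop over sequence positions, since each pick depends on the earlier ones, but process the whole batch at every step with a boolean mask of items already taken. The function name greedy_unique_argmax and the mask name taken are made up for illustration. It also assumes seq_length <= num_items, and it masks with -inf instead of 0 (a deliberate change from the heuristic above) so that an already chosen item cannot win even when all remaining probabilities are 0.

```python
import torch


def greedy_unique_argmax(probs: torch.Tensor) -> torch.Tensor:
    """Greedy argmax per position, skipping items already chosen in the same sequence.

    probs: float tensor of shape (batch_size, seq_length, num_items),
           assumed to satisfy seq_length <= num_items.
    Returns a LongTensor of shape (batch_size, seq_length).
    """
    batch_size, seq_length, num_items = probs.shape
    rows = torch.arange(batch_size, device=probs.device)
    # taken[b, k] is True once item k has been picked for batch element b
    taken = torch.zeros(batch_size, num_items, dtype=torch.bool, device=probs.device)
    picks = []
    for j in range(seq_length):
        # masked_fill returns a new tensor, so probs itself is left untouched
        step = probs[:, j, :].masked_fill(taken, float("-inf"))
        item = step.argmax(dim=-1)  # shape (batch_size,)
        taken[rows, item] = True
        picks.append(item)
    return torch.stack(picks, dim=1)


# Small usage example with hand-picked values (no ties)
example = torch.tensor([
    [[0.1, 0.7, 0.2], [0.3, 0.6, 0.1], [0.5, 0.4, 0.1]],
    [[0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.6, 0.1, 0.3]],
])
result = greedy_unique_argmax(example)
print(result)
```

This still runs seq_length Python iterations, but the inner loop over batch_size is gone, which is usually where most of the time goes when the batch is large. Whether it is fast enough will depend on the actual shapes and device.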