Is there anyway to make this nested for loop faster with tensor semantics?
def f(sequences, log_probs):
batch_size, seq_len = sequences.size()
# sequences = sequences.view(-1).long()
# seq_log_probs = log_probs[:, :, sequences]
seq_log_probs = torch.zeros_like(sequences).float()
for b in range(batch_size):
for t in range(seq_len):
idx = sequences[b, t].item()
seq_log_probs[b, t] = log_probs[b, t, idx].item()
return seq_log_probs.sum()