Batching training Denoising Autoencoder

I have this code for training a denoising autoencoder that uses an LSTM for encoder and decoder and operates on names

def denoise_train(x: DataLoader):
    loss = 0.
    noisy_x = list(map(lambda s: noise_name(s), x))

    rnn_x = to_rnn_tensor(x, DECODER_COUNT)
    rnn_noisy_x = to_rnn_tensor(noisy_x, ENCODER_COUNT)

    encoder_hidden = encoder.init_hidden(batch_size=BATCH_SZ)

    for i in range(rnn_noisy_x.shape[0]):
        _, encoder_hidden = encoder(rnn_noisy_x[i].unsqueeze(0), encoder_hidden)

    decoder_input = strings_to_tensor([SOS] * BATCH_SZ)

    decoder_hidden = encoder_hidden

    name = ''

    for i in range(rnn_x.shape[0]):

        decoder_probs, decoder_hidden = decoder(decoder_input, decoder_hidden)

        _, nonzero_indexes = rnn_x[i].topk(1)

        # TODO!!! Need to fix rest of code for batch

        best_index = torch.argmax(decoder_probs, dim=2).item()

        loss += criterion(decoder_probs[0], nonzero_indexes[0])

        name += ALL_CHARS[best_index]

        decoder_input = torch.zeros(1, 1, LETTERS_COUNT)

        decoder_input[0, 0, best_index] = 1.

    return name, noisy_x, loss.item()

The x that gets passed in is the next iteration of an iter(DataLoader). The main thing I’m trying to do is get the argmax for all the decoder_probs, which is of size name length x batch size x output length. So I need best_index to be the argmax for all entries in the batch and decoder_input should be a 1xbatch size x output where all the best chars = 1.