Compute forward on every variation of my inputs to get my backward


I am implementing a new energy-based model. My input is sequential and categorical (with a small number of categories). Right now my simple forward maps every sequence to a real number. However, I want to compute something more complicated: a pseudolikelihood.
To do so, for every input sequence I need to compute the forward pass on every sequence that differs from it by exactly one element, then use these values to renormalize with a softmax at every position.

My question is mainly the following: how can I create a function that, for every element in my batch, builds a temporary batch of all the single-element variations of that sequence, so that I can apply my energy computation + renormalization? It is a kind of “sub-batch” for every element in the batch.

Here is a simple NumPy implementation of how I would do it with for loops:

import numpy as np
from scipy.special import softmax

# For each sequence, each position, and each category: copy the sequence,
# overwrite that position with that category, and evaluate the energy.
energy_variation = np.zeros((batch_size, sequence_length, number_categories))
for seq_index in range(batch_size):
    seq = batch[seq_index]
    for i in range(sequence_length):
        for cat in range(number_categories):
            temp_seq = seq.copy()
            temp_seq[i] = cat
            energy_variation[seq_index, i, cat] = self.forward(temp_seq)
        # Renormalize over the categories at position i
        energy_variation[seq_index, i, :] = softmax(energy_variation[seq_index, i, :])
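For reference, here is one way this loop could be vectorized in PyTorch: build a `(B, L, C, L)` tensor holding every single-site variation, flatten it into one big batch, run the energy model once, and softmax over the category axis. This is a sketch under assumptions: `batch` is a `(B, L)` integer tensor and the model accepts a `(N, L)` batch of sequences (the names `all_single_site_variations` and `pseudolikelihood_softmax` are made up for illustration):

```python
import torch
import torch.nn.functional as F

def all_single_site_variations(batch, num_categories):
    """batch: (B, L) integer tensor -> (B, L, C, L) tensor where
    out[b, i, c] is batch[b] with position i replaced by category c."""
    B, L = batch.shape
    # One copy of each sequence per (position, category) pair
    var = batch.unsqueeze(1).unsqueeze(2).expand(B, L, num_categories, L).clone()
    pos = torch.arange(L)
    cats = torch.arange(num_categories)
    # Advanced indexing: for every (i, c), set position i of the (i, c)-th copy to c
    var[:, pos, :, pos] = cats
    return var

def pseudolikelihood_softmax(model, batch, num_categories):
    """Energies of all single-site variations, softmaxed per position."""
    B, L = batch.shape
    var = all_single_site_variations(batch, num_categories)       # (B, L, C, L)
    energies = model(var.reshape(-1, L)).reshape(B, L, num_categories)
    return F.softmax(energies, dim=-1)                            # (B, L, C)
```

Note that the flattened batch has `B * L * C` rows; if that is too large for memory, the reshaped tensor can be processed in chunks (e.g. with `torch.split`) before reassembling the `(B, L, C)` result.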

Btw, my loss is computed from some of the values of this energy_variation matrix, meaning it needs to backpropagate through all of these forward passes.
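As long as the variation tensor is built with ordinary PyTorch ops, autograd handles this automatically. A toy sketch (with a random `logits` tensor standing in for the model's energies, an assumption for illustration) of a negative pseudo-log-likelihood loss that gathers the probability of the true category at each position and backpropagates:

```python
import torch

B, L, C = 2, 4, 3
batch = torch.randint(0, C, (B, L))            # true categories
logits = torch.randn(B, L, C, requires_grad=True)
probs = torch.softmax(logits, dim=-1)          # stand-in for energy_variation
# -sum_i log p(x_i | rest of the sequence)
true_probs = probs.gather(-1, batch.unsqueeze(-1)).squeeze(-1)  # (B, L)
loss = -true_probs.log().sum(dim=1).mean()
loss.backward()                                # gradients reach logits
```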

Mainly: how can I do this efficiently in PyTorch?