Applying a shared softmax

I’m trying to implement the shared log_softmax used in this question-answering architecture (Section 3.3): http://aclweb.org/anthology/P18-1078. This is the relevant formula:
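
Roughly, with δ indicating whether the question has an answer, z the no-answer score, s_i and g_j the per-token start and end scores, and (a, b) the correct span:

$$-\log\frac{(1-\delta)\,e^{z} + \delta\, e^{s_a + g_b}}{e^{z} + \sum_{i}\sum_{j} e^{s_i + g_j}}$$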

How would I go about implementing it in PyTorch in a way that’s stable and differentiable? Is there any existing function that I can leverage?

Hi, your objective function is differentiable; you just need to write out the computation explicitly, for example like this:

import torch
from torch import exp, log

def loss(d, z, sa, gb, s, g):
    # Denominator term: sum over all (i, j) of exp(s_i + g_j)
    S = torch.zeros(1)
    for i in range(len(s)):
        for j in range(len(g)):
            S = S + exp(s[i] + g[j])

    return -log(((1 - d) * exp(z) + d * exp(sa + gb)) / (exp(z) + S))
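
Just to sanity-check with made-up toy tensors (sizes and values are arbitrary), gradients do flow through it:

s = torch.randn(5, requires_grad=True)
g = torch.randn(5, requires_grad=True)
z = torch.randn(1, requires_grad=True)

out = loss(d=1.0, z=z, sa=s[2], gb=g[3], s=s, g=g)
out.backward()   # populates s.grad, g.grad and z.grad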

My current implementation is fairly similar to that, but I found that computing all the sums individually dramatically slows down each training batch. I am also concerned about numerical stability when implementing softmax manually (or is that an unnecessary concern?). Any suggestions on how to optimize this?

The for loop could probably be reduced to this:

import torch

n = 10
s = torch.arange(n, dtype=torch.float)
g = torch.arange(n, dtype=torch.float)
# Broadcast to an (n, n) matrix of all pairwise sums s[i] + g[j], then sum over both dims
torch.exp(s + g.view(n, 1)).sum([0, 1])

Could you check if this gives the right answer? It should be faster than the for loops.
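
Regarding the stability concern: one option is to keep the whole loss in log space with torch.logsumexp, so no large exponentials are ever materialized. Here is a sketch (the function name is just for illustration), assuming z, sa and gb are scalar tensors, s and g are 1-D score tensors, and d is 0 or 1:

import torch

def loss_logspace(d, z, sa, gb, s, g):
    # All pairwise span scores s_i + g_j as an (n, m) matrix
    pair = s.unsqueeze(1) + g.unsqueeze(0)
    # log(e^z + sum_{i,j} e^{s_i + g_j}), computed stably in one call
    log_denom = torch.logsumexp(torch.cat([z.reshape(1), pair.reshape(-1)]), dim=0)
    # With d in {0, 1} the numerator is a single term, so its log is exact
    log_num = (1 - d) * z + d * (sa + gb)
    return log_denom - log_num

Mathematically this is the same loss, but it never exponentiates the raw scores directly, so it won’t overflow when the scores get large, and gradients flow through it the same way.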