I’m trying to implement the shared log_softmax found in this architecture for question-answering: http://aclweb.org/anthology/P18-1078, in section 3.3. This is the relevant formula:
How would I go about implementing it in PyTorch in a way that's stable and differentiable? Is there an existing function I can leverage?
Hi, your objective function is differentiable, but you need to code that computation like this:

```python
import torch
from torch import exp, log

def loss(d, z, sa, gb, s, g):
    n = s.size(0)
    # Accumulate the denominator term: sum over all (i, j) of exp(s_i + g_j).
    # Use zeros, not empty: torch.empty returns uninitialized memory.
    S = torch.zeros(1)
    for i in range(n):
        for j in range(n):
            S = S + exp(s[i] + g[j])
    return -log(((1 - d) * exp(z) + d * exp(sa + gb)) / (exp(z) + S))
```
My current implementation is fairly similar to that, but computing all the sums individually dramatically slows down each training batch. I'm also concerned about numerical stability when implementing softmax manually (or is that an unnecessary concern?). Any suggestions for optimizing it?
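On the stability question: one common option is to keep the whole computation in log space with `torch.logsumexp`, which avoids overflowing the explicit `exp` calls. A minimal sketch, assuming the same argument names as the loop version above (`d` a binary 0/1 indicator, `z`, `sa`, `gb` scalar scores, `s` and `g` score vectors; these names come from the snippet, not the paper):

```python
import torch

def loss_stable(d, z, sa, gb, s, g):
    # All pairwise scores s_i + g_j via broadcasting, flattened to one vector
    pair = (s.unsqueeze(1) + g.unsqueeze(0)).reshape(-1)
    # log(exp(z) + sum_ij exp(s_i + g_j)) without materializing the large exps
    log_den = torch.logsumexp(torch.cat([z.reshape(1), pair]), dim=0)
    # With a binary d, (1-d)*exp(z) + d*exp(sa+gb) just selects one of the
    # two terms, so pick it directly in log space
    log_num = torch.where(d.bool(), sa + gb, z)
    return log_den - log_num  # = -log(numerator / denominator)
```

`logsumexp` subtracts the maximum internally before exponentiating, so the result stays finite even for large scores, and it is differentiable end to end.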
The for loop could probably be reduced to this:
```python
import torch

n = 10
s = torch.arange(n, dtype=torch.float)
g = torch.arange(n, dtype=torch.float)
# Broadcasting: (n,) + (n, 1) gives an (n, n) grid of s[i] + g[j];
# summing over both dims (equivalently just .sum()) gives the scalar total
torch.exp(s + g.view(n, 1)).sum([0, 1])
```
Could you check if this gives the right answer? It should be faster than the for loops.
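A quick self-check of that reduction, comparing the broadcasted sum against the original double loop:

```python
import torch

n = 10
s = torch.arange(n, dtype=torch.float)
g = torch.arange(n, dtype=torch.float)

# Original double loop: sum of exp(s[i] + g[j]) over all pairs
loop_sum = sum(torch.exp(s[i] + g[j]) for i in range(n) for j in range(n))

# Broadcasted version: (n,) + (n, 1) gives an (n, n) grid, then sum all entries
vec_sum = torch.exp(s + g.view(n, 1)).sum()

print(torch.allclose(loop_sum, vec_sum))  # True
```

The two agree up to float32 rounding, and the broadcasted version runs as one batched kernel instead of n² small ones.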