I’m trying to implement the shared log_softmax used in this question-answering architecture (http://aclweb.org/anthology/P18-1078, section 3.3). This is the relevant formula:
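In LaTeX form (this is my reading of it):

-\log \frac{(1 - \delta)\, e^{z} + \delta\, e^{s_a + g_b}}{e^{z} + \sum_{i} \sum_{j} e^{s_i + g_j}}

where, as I understand it, \delta \in \{0, 1\} indicates whether the paragraph contains an answer, z is the learned no-answer score, s_i and g_j are the per-token start and end scores, and (a, b) is the gold answer span.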
Hi, your objective function is differentiable, but you need to write out the computation yourself, along these lines:
import torch
from torch import exp, log

def loss(d, z, sa, gb, s, g):
    # accumulate the partition term: exp of every candidate span score s[i] + g[j]
    S = torch.zeros(1)
    for i in range(len(s)):
        for j in range(len(g)):
            S = S + exp(s[i] + g[j])
    # negative log-likelihood; d selects the no-answer term exp(z) or the answer-span term
    return -log(((1 - d) * exp(z) + d * exp(sa + gb)) / (exp(z) + S))
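For example, with made-up values just to show the expected inputs (1-D score tensors over n tokens, and the gold span taken to be (2, 3)):

import torch
n = 5
s = torch.randn(n, requires_grad=True)   # start scores, one per token
g = torch.randn(n, requires_grad=True)   # end scores, one per token
z = torch.randn(1, requires_grad=True)   # no-answer score
d = torch.tensor(1.0)                    # 1.0 because an answer span exists
out = loss(d, z, s[2], g[3], s, g)
out.backward()                           # gradients flow back into s, g, z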
My current implementation is fairly similar to that, but I found that computing all of the sums individually in Python loops dramatically slows down each training batch. I am also concerned about numerical stability when implementing the softmax manually (or is that an unnecessary concern?). Any suggestions on ways to optimize this?
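For concreteness, this is the kind of vectorized version I have been considering (loss_vectorized is just my own name for it; it assumes d, z, sa and gb are scalar tensors, s and g are 1-D score tensors, and d is strictly 0 or 1):

import torch

def loss_vectorized(d, z, sa, gb, s, g):
    # all candidate span scores s[i] + g[j] at once via broadcasting: shape (len(s), len(g))
    span_scores = s.unsqueeze(1) + g.unsqueeze(0)
    # log(exp(z) + sum_ij exp(s_i + g_j)), kept in log space by logsumexp
    log_denom = torch.logsumexp(torch.cat([z.reshape(1), span_scores.reshape(-1)]), dim=0)
    # because d is 0 or 1, log((1 - d)*exp(z) + d*exp(sa + gb)) reduces to picking one of the two terms
    log_num = (1 - d) * z + d * (sa + gb)
    return log_denom - log_num

As far as I know, torch.logsumexp subtracts the max before exponentiating, so this should also take care of the overflow worry, and the double loop collapses into a single broadcasted addition. Does that look reasonable, or is there a better pattern for this?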