Gradients with Argmax in PyTorch

I was trying to implement the model from the paper “Dynamic Coattention Networks for QA” in PyTorch, and noticed that many of my parameters were not being trained at all. After some debugging, the problem seems to come from an argmax operation in the decoder (page 4 of the paper). Its output (i.e. the second return value of torch.max) has requires_grad set to False, which makes sense since argmax is not differentiable. However, the author of the paper trains his model using a basic Adam optimizer. How is this possible? What workaround would allow me to do the same?

The paper’s author uses Chainer, which shouldn’t be that different from PyTorch, right?

Also, I tried implementing this model in TensorFlow, and it did work as written. Why was that the case? Does TensorFlow implement argmax differently, in some ‘soft’ way?

Thanks!

Do you have a pointer to the model in different frameworks? The derivative of argmax is zero nearly everywhere, so it doesn’t seem likely that you can back-propagate through it in a way that is useful.
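To make the point concrete, here is a minimal PyTorch sketch (the tensor `scores` is a made-up stand-in for the decoder's scores, not anything from the paper). It shows that the values returned by torch.max carry gradients while the indices do not:

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for the decoder scores: 2 examples, 5 positions each.
scores = torch.randn(2, 5, requires_grad=True)

# torch.max along a dimension returns (values, indices).
values, indices = torch.max(scores, dim=-1)
print(values.requires_grad)   # True
print(indices.requires_grad)  # False

# Back-propagating through the max *values* works: each row of the gradient
# is 1 at the argmax position and 0 everywhere else.
values.sum().backward()
print(scores.grad)
```

So anything computed only from `indices` is cut off from the graph, while anything computed from the scores themselves can still be trained.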

I haven’t been able to find it implemented anywhere “officially”; the version I wrote in TensorFlow just called “tf.argmax”. It’s entirely possible that there was something wrong with that model though, because the performance wasn’t really up to scratch. Here’s the relevant decoder code from my implementation:

def decoder_loop(iteration, s, e, us, ue, h_c_state, h_m_state, loss):
    # Feed the encodings at the current start/end guesses into the decoder LSTM.
    decoder_endpoint_input = tf.concat([us, ue], axis=-1)
    _, (h_c_state, h_m_state) = decoder_lstm_cell(decoder_endpoint_input, (h_c_state, h_m_state), scope=scope)
    h = h_m_state
    # Score every position as a candidate start, then take the best one.
    alpha = hmn(U, h, us, ue, "decoder_start_hmn")
    s = tf.argmax(alpha, axis=-1)  # integer indices; no gradient flows through s
    us = tf.gather_nd(U, [[i, tens] for (i, tens) in enumerate(tf.unstack(s))])
    # Same for the end position, using the updated start encoding.
    beta = hmn(U, h, us, ue, "decoder_end_hmn")
    e = tf.argmax(beta, axis=-1)  # integer indices; no gradient flows through e
    ue = tf.gather_nd(U, [[i, tens] for (i, tens) in enumerate(tf.unstack(e))])
    # The loss is computed on the raw scores (alpha, beta), not on the argmax outputs.
    loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=alpha, labels=answer_starts))
    loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=beta, labels=answer_ends))
    return (tf.add(iteration, 1), s, e, us, ue, h_c_state, h_m_state, loss)

_, s, e, _, _, _, _, loss = tf.while_loop(looping_cond, decoder_loop, (iteration_0, s_0, e_0, us_0, ue_0, h_c_state_0, h_m_state_0, loss_0), parallel_iterations=1)
tf.summary.scalar("loss", loss)
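One possible reading of the code above is that the loss is taken on the scores alpha and beta directly, so the scoring network can receive gradients even though the argmax indices themselves do not. The PyTorch sketch below illustrates that pattern under loud assumptions: `scorer` is a hypothetical single linear layer standing in for hmn, the shapes and labels are made up, and the way `us` feeds into the second scoring step is simplified. It is not the paper's decoder.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, m, d = 2, 5, 4
U = torch.randn(batch, m, d, requires_grad=True)  # stand-in for the encoder output
scorer = torch.nn.Linear(d, 1)                    # hypothetical stand-in for hmn
answer_starts = torch.tensor([1, 3])              # made-up labels
answer_ends = torch.tensor([2, 4])

# Start scores and the hard argmax over positions.
alpha = scorer(U).squeeze(-1)                     # (batch, m) logits
s = alpha.argmax(dim=-1)                          # (batch,) int64, not differentiable
us = U[torch.arange(batch), s]                    # (batch, d), differentiable w.r.t. U

# End scores depend on the selected start encoding (simplified for illustration).
beta = scorer(U + us.unsqueeze(1)).squeeze(-1)    # (batch, m) logits

# Cross-entropy on the logits, mirroring the structure of the TensorFlow loss.
loss = F.cross_entropy(alpha, answer_starts, reduction="sum")
loss = loss + F.cross_entropy(beta, answer_ends, reduction="sum")
loss.backward()

print(s.requires_grad)                 # False
print(scorer.weight.grad is not None)  # True
print(U.grad is not None)              # True
```

In this sketch the only path that is cut is the dependence of `us` on the choice of index; the scorer and the encodings still get gradients through the logits and through the gathered rows.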

I still haven’t really gotten to the bottom of this issue - does anyone have any insight?

Have you tried a soft argmax?
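For reference, a soft argmax usually means replacing the hard index selection with a softmax-weighted average, so gradients can flow back into the scores. Below is a minimal sketch in PyTorch; the helper `soft_select`, the `temperature` parameter, and the tensor shapes are assumptions made for illustration. Note that this changes the model relative to the hard argmax described in the paper.

```python
import torch

torch.manual_seed(0)
batch, m, d = 2, 5, 4
U = torch.randn(batch, m, d)                       # stand-in for the encoder output
alpha = torch.randn(batch, m, requires_grad=True)  # stand-in for the start scores

def soft_select(U, scores, temperature=1.0):
    """Hypothetical helper: softmax-weighted average of the rows of U."""
    # Softmax weights over positions; a lower temperature makes them more peaked.
    weights = torch.softmax(scores / temperature, dim=-1)  # (batch, m)
    # Weighted average over positions instead of picking a single row.
    return torch.bmm(weights.unsqueeze(1), U).squeeze(1)   # (batch, d)

us_soft = soft_select(U, alpha, temperature=0.5)
us_soft.sum().backward()
print(alpha.grad is not None)  # True
```

The trade-off is that `us_soft` is a blend of several positions rather than the encoding of one position, so the decoder sees a different input than it would with a hard selection.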