Adam implementation: PyTorch vs Trax

I was just watching a lecture by a great researcher and engineer who is also Trax's author.
He claimed that the implementation of the Adam optimizer is complicated in TensorFlow and Keras, and complicated even in PyTorch. He demonstrated the implementation in Trax, and it did indeed look quite simple.
Can you point me to the implementation of the Adam optimizer in PyTorch, and could it be implemented more simply? Why is such deep knowledge of PyTorch needed to implement such a simple thing?



You can find the implementation in this file.

The implementation is not as simple as it could be, mainly for performance reasons: it uses in-place operations and reuses buffers as much as possible to get very good speed even though it is implemented in Python.
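For illustration (these variable names are made up, this is not the actual torch.optim.Adam source), here are the two styles for the first-moment update. The out-of-place version reads like the math, while the in-place version mutates the existing buffer and avoids allocating a new tensor on every step:

```python
import torch

beta1 = 0.9
grad = torch.randn(1000)

# Out-of-place: readable, but allocates a fresh tensor every step.
exp_avg_a = torch.zeros(1000)
exp_avg_a = beta1 * exp_avg_a + (1 - beta1) * grad

# In-place: mutates the existing buffer, no new allocation.
exp_avg_b = torch.zeros(1000)
exp_avg_b.mul_(beta1).add_(grad, alpha=1 - beta1)

# Both compute the same exponential moving average.
assert torch.allclose(exp_avg_a, exp_avg_b)
```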

But even beyond that, there is quite a bit of statistics to compute and bookkeeping to do for Adam, so the code is not trivial.
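Concretely, that bookkeeping is a per-parameter state dict holding the step count and both moment buffers, all of which has to survive checkpointing (key names as in recent PyTorch versions):

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())

model(torch.randn(1, 4)).sum().backward()
opt.step()

# Each parameter carries its own state: step count plus both moments.
state = opt.state[model.weight]
print(sorted(state.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']

# state_dict() packages all of it so training can resume from a checkpoint.
ckpt = opt.state_dict()
```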


He defined the optimizer in Trax as:

class Adam(Optimizer):

  def init(self, weights):
    m = jnp.zeros_like(weights)
    v = jnp.zeros_like(weights)
    return m, v

  def update(self, step, grads, weights, slots, opt_params):
    m, v = slots
    learning_rate = opt_params['learning_rate']
    weight_decay_rate = opt_params['weight_decay_rate']
    b1 = opt_params['b1']
    b2 = opt_params['b2']
    eps = opt_params['eps']
    m = (1 - b1) * grads + b1 * m  # First  moment estimate.
    v = (1 - b2) * (grads ** 2) + b2 * v  # Second moment estimate.
    mhat = m / (1 - b1 ** (step + 1))  # Bias correction.
    vhat = v / (1 - b2 ** (step + 1))
    new_weights = ((1 - weight_decay_rate) * weights - (
        learning_rate * mhat / (jnp.sqrt(vhat) + eps))).astype(weights.dtype)
    return new_weights, (m, v)

I would like to see how our PyTorch code would look if none of these performance and memory optimizations were used, e.g. if I were building a new optimizer on top of PyTorch :thinking:

The update code could (almost) be copy-pasted :stuck_out_tongue:
It seems like some boilerplate code to help with creating the buffers is missing from the Trax example (I'm not sure where that code is).
But other than that, that's it, if you remove the input-parameter validation and the empty lines between the math.
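As a rough sketch of that (not the actual torch.optim.Adam code, and the class name is made up), a custom PyTorch optimizer that mirrors the Trax math with none of the in-place buffer reuse could look like:

```python
import torch

class SimpleAdam(torch.optim.Optimizer):
    """Adam with the Trax-style math and no in-place buffer reuse."""

    def __init__(self, params, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8,
                 weight_decay_rate=0.0):
        defaults = dict(lr=lr, b1=b1, b2=b2, eps=eps,
                        weight_decay_rate=weight_decay_rate)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:  # lazy init, the Trax init() equivalent
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p)
                    state['v'] = torch.zeros_like(p)
                b1, b2 = group['b1'], group['b2']
                m = (1 - b1) * p.grad + b1 * state['m']       # first moment
                v = (1 - b2) * p.grad ** 2 + b2 * state['v']  # second moment
                state['step'] += 1
                mhat = m / (1 - b1 ** state['step'])  # bias correction
                vhat = v / (1 - b2 ** state['step'])
                new_p = ((1 - group['weight_decay_rate']) * p
                         - group['lr'] * mhat / (vhat.sqrt() + group['eps']))
                p.copy_(new_p)
                state['m'], state['v'] = m, v
```

This drops the input validation and the buffer reuse that the real implementation carries for speed, which is where most of the extra code comes from.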

We are actually refactoring the optimizers right now to move the update code outside of the state tracking. As you can see in the PR here, the update function is no longer than this one.


:smiley: Thanks for refactoring to make it as Pythonic and research-friendly as possible. I'm here to support.