When I look at the source code of torch.optim.Adam, it looks like there are two functions that perform the bulk of the computations: _single_tensor_adam and _multi_tensor_adam. However, these are not member functions of torch.optim.Adam; instead they only seem to be called from the function adam, which is in turn called by the step method of torch.optim.Adam.
Is there a recommended way to modify the Adam optimizer? Is it recommended to derive from torch.optim.Adam and override the step function, or is it better to create a completely new optimizer from scratch by deriving from torch.optim.Optimizer? And what is the purpose of the functions adam, _single_tensor_adam and _multi_tensor_adam? Are those something I should implement too?
adam is the functional API which is called in torch.optim.Adam.step as you described. _single_tensor_adam is the default implementation, which you could use as a template for your custom implementation. _multi_tensor_adam is a performance optimization, which you could ignore.
Thank you, @ptrblck. When you say that I could use _single_tensor_adam as a template for my custom implementation, do you mean that I could override the step method with a method that looks like _single_tensor_adam?
Yes, you can either copy the body of _single_tensor_adam into your step method directly or call it via the functional adam API. For the sake of simplicity, you could directly reuse _single_tensor_adam's logic in step.
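To make that concrete, here is a minimal sketch of the subclass-and-override approach: a hypothetical MyAdam that derives from torch.optim.Adam and reimplements step with a plain Python loop modeled on _single_tensor_adam. It deliberately ignores options such as amsgrad, maximize, and the foreach/fused paths, and the class name and the commented update line are just placeholders for wherever you want to change the update rule.

```python
import math
import torch


class MyAdam(torch.optim.Adam):
    """Hypothetical custom Adam that overrides step() with a plain loop
    modeled on _single_tensor_adam (no amsgrad/maximize/foreach/fused)."""

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                # Lazily initialize the per-parameter state, as the stock
                # optimizer does on the first step.
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1
                step = state["step"]

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                # Update biased first and second moment estimates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                step_size = lr / bias_correction1
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Place your custom modification here before/instead of the
                # standard parameter update below.
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
```

You would construct it exactly like the stock optimizer, e.g. opt = MyAdam(model.parameters(), lr=1e-3); since the constructor and param group handling are inherited from torch.optim.Adam, only the update loop differs.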