Hi, I believe the Rprop implementation in PyTorch is slightly incorrect (according to https://www.cs.cmu.edu/~bhiksha/courses/deeplearning/Fall.2016/pdfs/Rprop.pdf).
Here is an example on the quadratic function f(x) = x^2:
import torch

param = torch.tensor(2., requires_grad=True)
opt = torch.optim.Rprop([param], 1)  # lr=1 is the initial step size
last_param = param.clone()
loss = param**2
print(f'param = {round((last_param).item(), 3)}, loss = {round(loss.item(), 3)}')
for i in range(10):
    opt.zero_grad()
    loss.backward()
    opt.step()
    loss = param**2
    # param.grad still holds the gradient from before this step
    update = round((param-last_param).item(), 3)
    last_param = param.clone()
    print(f'param = {round((last_param).item(), 3)}, loss = {round(loss.item(), 3)}, grad = {round(param.grad.item(), 3)}, {update = }')
These are the first six output lines:
param = 2.0, loss = 4.0
param = 1.0, loss = 1.0, grad = 4.0, update = -1.0
param = -0.2, loss = 0.04, grad = 2.0, update = -1.2
param = -0.2, loss = 0.04, grad = -0.4, update = 0.0
param = 0.4, loss = 0.16, grad = -0.4, update = 0.6
param = 0.4, loss = 0.16, grad = 0.8, update = 0.0
Notice the last line. When the derivative changes sign, Rprop is supposed to perform a ‘backtracking’ weight-step.
If the partial derivative changes sign, i.e. the previous step was too large and the minimum was missed, the previous weight-update is reverted. Due to that 'backtracking' weight-step, the derivative is supposed to change its sign once again in the following step.
However, PyTorch doesn’t perform a backtracking step; instead it performs no step at all for that weight.
The -1.2 update moved the parameter from 1 to -0.2, which flips the sign of the gradient to negative. Rprop is supposed to undo that update and move back to 1, while multiplying the step size by η- (0.5 by default in torch.optim.Rprop). Step-size adaptation is skipped on the following step, because backtracking flips the sign of the gradient back to positive. So the next update will be -0.6, which moves param to 0.4. However, PyTorch doesn’t perform the backtracking step and keeps the parameter at -0.2.
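To make the difference concrete, here is a rough pure-Python sketch of scalar Rprop on f(x) = x^2. It is only my reading of the two behaviours, not PyTorch's actual code: `rprop_trace` and its `backtrack` flag are made-up names, and the defaults (η- = 0.5, η+ = 1.2, step bounds 1e-6 and 50) are assumed to mirror `torch.optim.Rprop`. With `backtrack=False` it skips the update on a sign change, which matches the output above; with `backtrack=True` it reverts the previous update instead.

```python
def sign(g):
    """Return -1, 0 or 1 depending on the sign of g."""
    return (g > 0) - (g < 0)


def rprop_trace(x0, n_steps, backtrack, lr=1.0, eta_minus=0.5, eta_plus=1.2,
                step_min=1e-6, step_max=50.0):
    """Minimise f(x) = x**2 with scalar Rprop; return the visited parameters.

    Hypothetical helper for illustration only.
    """
    x, step = x0, lr
    prev_grad, prev_update = 0.0, 0.0
    trace = []
    for _ in range(n_steps):
        grad = 2.0 * x  # f'(x) for f(x) = x**2
        if grad * prev_grad > 0:
            # Same sign as last time: grow the step size and keep going.
            step = min(step * eta_plus, step_max)
            update = -sign(grad) * step
        elif grad * prev_grad < 0:
            # Sign flip: the minimum was overshot, so shrink the step size.
            step = max(step * eta_minus, step_min)
            # Either revert the previous update (backtracking) or do nothing.
            update = -prev_update if backtrack else 0.0
            grad = 0.0  # disables adaptation on the following step
        else:
            # First step, or the step right after a sign flip: no adaptation.
            update = -sign(grad) * step
        x += update
        prev_grad, prev_update = grad, update
        trace.append(round(x, 3))
    return trace


print(rprop_trace(2.0, 5, backtrack=False))  # [1.0, -0.2, -0.2, 0.4, 0.4]
print(rprop_trace(2.0, 5, backtrack=True))   # [1.0, -0.2, 1.0, 0.4, -0.32]
```

The first trace reproduces the param values PyTorch printed, and the second one follows the backtracking rule described above.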
Here is what I believe the output should look like:
param = 2.0, loss = 4.0
param = 1.0, loss = 1.0, grad = 4.0, update = -1.0
param = -0.2, loss = 0.04, grad = 2.0, update = -1.2
param = 1.0, loss = 1.0, grad = -0.4, update = 1.2
param = 0.4, loss = 0.16, grad = 2.0, update = -0.6
param = -0.32, loss = 0.102, grad = 0.8, update = -0.72