I want to update the weights (theta) of my neural network using the update rule below (t+1 denotes the next epoch, A is a constant, and f and g are functions of theta):
theta_{t+1} = theta_t + A * f(theta_t) * \nabla[ g(theta_t) + f(theta_t) ]
where \nabla denotes the derivative of g and f with respect to theta.
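To make the intent concrete, here is a toy, optimizer-free sketch of the update I have in mind (the parameter theta and the functions f and g below are simple stand-ins, not my real network):

import torch

theta = torch.randn(3, requires_grad=True)   # stand-in for the network weights
A = 0.1

f = (theta ** 2).sum()    # f(theta), a scalar
g = theta.sum()           # g(theta), a scalar
(grad,) = torch.autograd.grad(g + f, theta)  # \nabla[ g(theta) + f(theta) ]

with torch.no_grad():
    theta += A * f * grad  # theta_{t+1} = theta_t + A * f(theta_t) * grad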
Currently, I am implementing this update rule as follows (assume f and g are torch tensors with a single element that depend on theta, the weights of the neural network):
optimizer.zero_grad()
loss = A * f.detach() * (g + f)   # detach f so the leading f(theta) factor is treated as a constant
loss.backward()                   # computes A * f * d[g(theta) + f(theta)]/d(theta)
optimizer.step()
Is this the correct way of implementing this weight update rule in PyTorch?
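For completeness, here is a self-contained version of what I am currently running; the linear model and the way f and g are computed are simplified placeholders for my real setup:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)                                   # placeholder for my real network
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
A = 0.1
x = torch.randn(8, 4)

out = model(x)
f = out.pow(2).mean()   # placeholder scalar that depends on theta
g = out.mean()          # placeholder scalar that depends on theta

optimizer.zero_grad()
loss = A * f.detach() * (g + f)
loss.backward()
optimizer.step()        # SGD applies theta <- theta - lr * grad(loss)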