I have the following equation:

``````w <- w + a[R + gamma * q(S', A', w) - q(S, A, w)] * gradient of q(S, A, w)
``````

My code:

``````import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(5, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.tanh(self.fc1(x))
        return F.sigmoid(self.fc2(x))

env = gym.make('CartPole-v0')
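policy = Policy()
optimizer = optim.SGD(policy.parameters(), lr=0.01)  # not shown in the post; plain SGD and lr are assumed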

s = env.reset()
a = 0

s_next, reward, done, _ = env.step(a)
a_next = 1

x = torch.from_numpy(np.append(s, a)).float().unsqueeze(0)
x = Variable(x)
q_sa = policy(x)

x_next = torch.from_numpy(np.append(s_next, a_next)).float().unsqueeze(0)
x_next = Variable(x_next)
q_sa_next = policy(x_next)

alpha = 1
gamma = 1
update = alpha * (reward + gamma * q_sa_next - q_sa)
update.backward()
optimizer.step()
``````

Am I implementing `update` and propagating things correctly? I don’t quite understand how to handle the gradient part of the equation.

`q_sa_next` should have its gradients stripped, i.e. call `Variable(q_sa_next.data)` or `q_sa_next.detach()`.
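To make the effect of stripping the gradient concrete, here is a minimal sketch in PyTorch 0.4+ style (where `Variable` is merged into `Tensor`). The linear `q_net`, the random inputs, and the `reward`/`gamma` values are made up for illustration and are not the network from the question:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q_net = nn.Linear(5, 1)        # hypothetical stand-in for the Policy network
x = torch.randn(1, 5)          # made-up (state, action) features
x_next = torch.randn(1, 5)     # made-up (next state, next action) features
reward, gamma = 1.0, 1.0

q_sa = q_net(x)
q_sa_next = q_net(x_next).detach()   # bootstrap target: treated as a constant

delta = reward + gamma * q_sa_next - q_sa
delta.sum().backward()

# For a linear q, d(delta)/dw is -x once the target is detached;
# without .detach() it would be gamma * x_next - x instead.
grad_matches = torch.allclose(q_net.weight.grad, -x)
print(q_sa_next.requires_grad, grad_matches)
```

With the target detached, only `q(S, A, w)` contributes to the gradient, which is what the update rule in the question asks for.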

If it were me, I’d simply call `.data` on it to strip out the gradients, and then add `Variable` into the `update` line:

``````update = alpha * (Variable(reward + gamma * q_sa_next) - q_sa)
``````

I never quite figured out how to backpropagate `update` directly though, so I do it like this:

``````crit = nn.MSELoss()
loss = crit(q_sa, Variable(reward + gamma * q_sa_next))
loss.backward()
``````
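For reference, a self-contained sketch of one such step with a loss function, written for PyTorch 0.4+ without `Variable`. The small `q_net`, the SGD learning rate, and the random inputs are assumptions for illustration, not values from this thread:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical q-network over a 5-dim (state, action) input.
q_net = nn.Sequential(nn.Linear(5, 128), nn.Tanh(), nn.Linear(128, 1))
optimizer = optim.SGD(q_net.parameters(), lr=0.01)   # assumed optimizer and step size
crit = nn.MSELoss()

x = torch.randn(1, 5)         # made-up (S, A) features
x_next = torch.randn(1, 5)    # made-up (S', A') features
reward, gamma = 1.0, 0.99

q_sa = q_net(x)
with torch.no_grad():                        # target carries no gradient
    target = reward + gamma * q_net(x_next)

loss = crit(q_sa, target)
optimizer.zero_grad()     # clear gradients left over from earlier steps
loss.backward()           # gradient flows through q_sa only
optimizer.step()
```

Computing the target under `torch.no_grad()` plays the same role as calling `.data` or `.detach()` on `q_sa_next`.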

(Edit: I guess that to backpropagate `update`, we’d do something like `q_sa.backward(update)`. But as I say, I just lazily do it using `crit` for now, and I have never tested this approach.)

Thank you very much, I’ve got a much better understanding now.

The only thing I’m doing slightly differently to try to reproduce the formula is to use `nn.L1Loss(size_average=False)`.

Interesting. Thanks! Hmmm, so, you got me worried that I was using the wrong loss function. So, I worked through it a bit: https://github.com/hughperkins/pub-prototyping/blob/master/maths/sarsa_updates.ipynb

… and I’m fairly sure of two things:

1. it should be MSE loss (though there might be some issue with my reasoning above?), and
2. to directly backprop `delta` without needing a loss function, we can simply do:

`q.backward(-delta)`
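A quick way to check this claim is to compare the gradient from `q.backward(-delta)` against the gradient of a half squared-error loss. The sketch below does that on a made-up linear `q_net` with random inputs; all names and values are illustrative assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q_net = nn.Linear(5, 1)        # hypothetical q-function
x = torch.randn(1, 5)          # made-up (S, A) features
x_next = torch.randn(1, 5)     # made-up (S', A') features
reward, gamma = 1.0, 0.9

def td_terms():
    """Return q(S, A, w) with gradients and a gradient-free target."""
    q = q_net(x)
    with torch.no_grad():
        target = reward + gamma * q_net(x_next)
    return q, target

# Route 1: push -delta straight into q as the upstream gradient.
q, target = td_terms()
delta = (target - q).detach()
q_net.zero_grad()
q.backward(-delta)
grad_direct = q_net.weight.grad.clone()

# Route 2: differentiate 0.5 * delta^2 with the target held constant.
q, target = td_terms()
q_net.zero_grad()
loss = 0.5 * (target - q).pow(2).sum()
loss.backward()
grad_loss = q_net.weight.grad.clone()

same = torch.allclose(grad_direct, grad_loss)
print(same)
```

Both routes leave `-delta * gradient of q` in `.grad`, so a plain SGD step with learning rate `a` then performs `w <- w + a * delta * gradient of q`. Note that `nn.MSELoss` omits the factor 1/2, so for a single element its gradient is twice this one, which only rescales the step size.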

Thoughts?

Okay, I misunderstood; I did not consider the mathematical justification.

I just thought that `R + gamma * q(S', A', w) - q(S, A, w)` is about moving the weights of `q(S, A, w)` towards `R + gamma * q(S', A', w)`, and that implementing it with any loss function that encoded this difference would do a similar thing.
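The difference between the two losses shows up in the gradient with respect to `q`. A small sketch with made-up scalar values (`reduction='sum'` is the current spelling of the older `size_average=False`):

```python
import torch
import torch.nn as nn

q = torch.tensor([[0.5]], requires_grad=True)   # made-up q(S, A, w)
target = torch.tensor([[2.0]])                  # made-up R + gamma * q(S', A', w)

# L1 loss: the gradient w.r.t. q is only the sign of (q - target).
l1 = nn.L1Loss(reduction='sum')(q, target)
(grad_l1,) = torch.autograd.grad(l1, q)

# Half squared error: the gradient w.r.t. q is (q - target) = -delta.
half_sq = 0.5 * (q - target).pow(2).sum()
(grad_sq,) = torch.autograd.grad(half_sq, q)

print(grad_l1.item(), grad_sq.item())   # -1.0 -1.5
```

So both losses move `q(S, A, w)` towards the target, but only the squared error scales the step by the size of the TD error, as the update rule does.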

``````(1/2) * Q_w^2 - r * Q_w - gamma * Q'_w * Q_w + C
``````

``````(1/2) * (Q_w - (r + gamma * Q'_w))^2 + C
``````
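Spelled out as a short derivation, assuming the bootstrap target `r + gamma * Q'_w` is held constant when differentiating (the semi-gradient view) and writing `delta` for the TD error, the squared-error form gives back the update rule from the question:

```latex
\begin{aligned}
\delta &= r + \gamma\, Q'_w - Q_w \\
L(w) &= \tfrac{1}{2}\bigl(Q_w - (r + \gamma\, Q'_w)\bigr)^2 \\
\nabla_w L &= \bigl(Q_w - (r + \gamma\, Q'_w)\bigr)\,\nabla_w Q_w = -\delta\,\nabla_w Q_w \\
w &\leftarrow w - \alpha\,\nabla_w L = w + \alpha\,\delta\,\nabla_w Q_w
\end{aligned}
```

The constant `C` drops out because it does not depend on `w`.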