Implementing Sarsa weight updates

(Douglas Teoh) #1

I have the following equation:

w <- w + a[R + gamma * q(S', A', w) - q(S, A, w)] * gradient of q(S, A, w)

My code:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import backward, Variable

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(5, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.tanh(self.fc1(x))
        return F.sigmoid(self.fc2(x))

env = gym.make('CartPole-v0')
policy = Policy()
optimizer = optim.Adam(policy.parameters())

s = env.reset()
a = 0

s_next, reward, done, _ = env.step(a)
a_next = 1

x = torch.from_numpy(np.append(s, a)).float().unsqueeze(0)
x = Variable(x)
q_sa = policy(x)

x_next = torch.from_numpy(np.append(s_next, a_next)).float().unsqueeze(0)
x_next = Variable(x_next)
q_sa_next = policy(x_next)

alpha = 1
gamma = 1
update = alpha * (reward + gamma * q_sa_next - q_sa)

Am I implementing update and propagating things correctly? I don’t quite understand how to handle the gradient part of the equation.

(Hugh Perkins) #2

q_sa_next should have its gradients stripped, i.e. call Variable(q_sa_next.data) or q_sa_next.detach().

If it were me, I’d simply call .data on it to strip out the gradients, and then add Variable into the update line:

update = alpha * (Variable(reward + gamma * q_sa_next) - q_sa)

I never quite figured out how to backpropagate update directly though, so I do it like this:

crit = nn.MSELoss()
loss = crit(q_sa, Variable(reward + gamma * q_sa_next))

(Edit: I guess that to backpropagate update, we’d do something like q_sa.backward(update). But as I say I just lazily do it using crit for now, and never tested this approach)
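
For readers on current PyTorch, the loss-based approach above can be sketched without Variable. This is a minimal sketch under stated assumptions, not the code from the thread: q_net is a hypothetical stand-in for Policy with the output sigmoid dropped, the transition features are random placeholders rather than real CartPole observations, and plain SGD is used so that the learning rate plays the role of alpha.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the thread's network: 4 state dims + 1 action -> Q-value.
# Assumption: no sigmoid on the output, so Q-values are not squashed into (0, 1).
q_net = nn.Sequential(nn.Linear(5, 128), nn.Tanh(), nn.Linear(128, 1))
optimizer = torch.optim.SGD(q_net.parameters(), lr=0.01)  # lr acts as alpha

# Made-up features standing in for (S, A) and (S', A').
x = torch.randn(1, 5)
x_next = torch.randn(1, 5)
reward, gamma = 1.0, 0.99

q_sa = q_net(x)
with torch.no_grad():                      # the TD target is treated as a constant
    target = reward + gamma * q_net(x_next)

# Half squared error: d(loss)/dw = -(target - q_sa) * grad q(S, A, w)
loss = 0.5 * (target - q_sa).pow(2).sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()                           # w <- w + lr * delta * grad q(S, A, w)

q_sa_after = q_net(x)                      # re-evaluate to see the effect of the step
```

Computing the target under torch.no_grad() has the same effect as the .data / .detach() suggestion: no gradient flows through q(S', A', w), which is what makes this a semi-gradient update.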

(Douglas Teoh) #3

Thank you very much, I’ve got a much better understanding now.

The only thing I’m doing slightly differently, to try to reproduce the formula, is using nn.L1Loss(size_average=False).
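
To see how the choice of loss changes the gradient that reaches q(S, A, w), the two losses can be compared on a single made-up value. This is only an illustrative sketch: the numbers are invented, and reduction='sum' is used as the current spelling of size_average=False.

```python
import torch
import torch.nn as nn

q_sa = torch.tensor([[0.2]], requires_grad=True)   # made-up q(S, A, w)
target = torch.tensor([[1.7]])                     # made-up R + gamma * q(S', A', w)

# L1 loss: the gradient w.r.t. q_sa is sign(q_sa - target), whatever the size of the error.
l1 = nn.L1Loss(reduction='sum')(q_sa, target)
(grad_l1,) = torch.autograd.grad(l1, q_sa)

# Half squared error: the gradient w.r.t. q_sa is (q_sa - target), i.e. minus the TD error.
half_mse = 0.5 * nn.MSELoss(reduction='sum')(q_sa, target)
(grad_mse,) = torch.autograd.grad(half_mse, q_sa)

delta = (target - q_sa).item()                     # TD error
print(grad_l1.item(), grad_mse.item(), delta)
```

Only the squared-error gradient scales with the TD error, which is the factor multiplying the gradient of q(S, A, w) in the update rule from post #1.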

(Hugh Perkins) #4

Interesting. Thanks! :)

(Hugh Perkins) #5

Hmmm, so, you got me worried that I was using the wrong loss function :)

So, I worked through it a bit

… and I’m fairly sure of two things:

  1. it should be MSE loss (though there might be some issue with my reasoning above), and
  2. to directly backprop delta without needing a loss function, we can do simply:
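
One way to backprop the TD error without a loss function, sketched here under assumptions rather than taken from the thread, is to seed the backward pass of q_sa with minus delta. The network q_net and the random features are hypothetical placeholders. The sign is the part to watch: with an optimizer that does gradient descent, q_sa.backward(update) as guessed in post #2 would, as far as I can tell, need the update negated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical q(S, A, w) and made-up features for (S, A) and (S', A').
q_net = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 1))
x, x_next = torch.randn(1, 5), torch.randn(1, 5)
reward, gamma = 1.0, 0.99

with torch.no_grad():                       # target carries no gradient
    target = reward + gamma * q_net(x_next)

# (a) Backprop the TD error directly: the upstream gradient is -delta.
q_net.zero_grad()
q_sa = q_net(x)
delta = target - q_sa.detach()
q_sa.backward(gradient=-delta)
grads_direct = [p.grad.clone() for p in q_net.parameters()]

# (b) Half squared error against the same constant target.
q_net.zero_grad()
q_sa = q_net(x)
loss = 0.5 * (target - q_sa).pow(2).sum()
loss.backward()
grads_loss = [p.grad.clone() for p in q_net.parameters()]

# The two routes should leave the same gradients on the parameters.
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_direct, grads_loss))
print(same)
```

Either route leaves -delta * grad q(S, A, w) in each parameter's .grad, so a descent step with learning rate alpha moves w by + alpha * delta * grad q(S, A, w).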



(Douglas Teoh) #6

Okay, I misunderstood; I did not consider the mathematical justification.

I just thought that R + gamma * q(S', A', w) - q(S, A, w) is about moving the weights of q(S, A, w) towards R + gamma * q(S', A', w), and that implementing it with any loss function that encoded this difference would do a similar thing.

I don’t quite follow how:

(1/2) * Q_w^2 - rQ_w - yQ'_w + C

can turn into:

(1/2) * (Q_w - (r + yQ'_w))^2 + C
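
A sketch of the step in question, under two assumptions that the post does not state explicitly: the third term of the first expression is meant to be γ Q'_w Q_w (with y standing for gamma), and Q'_w is treated as a constant with respect to w, as in the semi-gradient update. Expanding the square gives:

```latex
\begin{aligned}
\tfrac{1}{2}\bigl(Q_w - (r + \gamma Q'_w)\bigr)^2
  &= \tfrac{1}{2} Q_w^2 - (r + \gamma Q'_w)\,Q_w + \tfrac{1}{2}(r + \gamma Q'_w)^2 \\
  &= \tfrac{1}{2} Q_w^2 - r\,Q_w - \gamma Q'_w\,Q_w
     + \underbrace{\tfrac{1}{2}(r + \gamma Q'_w)^2}_{\text{constant in } w}
\end{aligned}
```

Under those assumptions the two expressions differ only by a term that does not depend on w, so it can be absorbed into C. Both then have the same gradient, -(r + γ Q'_w - Q_w) times the gradient of Q_w, which is minus the update direction in the Sarsa rule.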