Hello,
I’m trying to learn how to implement a simple RL approach on a toy example. The goal is to generate the sequence of tokens (actions) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. A state is the sequence of tokens generated so far, padded with 0, e.g., [2, 8, 3, 0, 0, 0, 0, 0, 0, 0] after the first 3 steps. I assume a reward is only available for a complete sequence.
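A minimal sketch of the environment described above, assuming a plain Python list as the state and a hypothetical `step` helper that is not part of the code below. Each action writes one token into the zero-padded state, and a reward is returned only once the last position is filled, using the same prefix rule as `get_rw`.

```python
SEQ_LEN = 10

def prefix_reward(seq):
    # Same rule as get_rw in the post: +1 per leading token equal to its 1-based position
    r = 0
    for i, tok in enumerate(seq):
        if tok != i + 1:
            break
        r += 1
    return r

def step(state, t, token):
    """Hypothetical environment step: write `token` at position t of the
    zero-padded state. The reward is only returned once the sequence is complete."""
    new_state = list(state)
    new_state[t] = token
    done = t == SEQ_LEN - 1
    reward = prefix_reward(new_state) if done else None
    return new_state, reward, done

state = [0] * SEQ_LEN  # initial state: all padding
for t, token in enumerate([2, 8, 3]):
    state, reward, done = step(state, t, token)
print(state, reward, done)  # [2, 8, 3, 0, 0, 0, 0, 0, 0, 0] None False
```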
The code is the following:
```python
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical

SEQ_LEN = 10
EPOCHS = 10000
HIDDEN_SIZE = 50
ROLLOUT_NUM = 10
N_TOKENS = 11
LR = 1e-3
EPS = np.finfo(np.float32).eps.item()

model = nn.Sequential(nn.Linear(SEQ_LEN, HIDDEN_SIZE), nn.Linear(HIDDEN_SIZE, N_TOKENS), nn.LogSoftmax(-1))
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Compute reward: +1 for each leading token in its correct position, counting
# stops at the first mismatch. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> RMAX
def get_rw(beh):
    r = 0.
    for i in range(SEQ_LEN):
        if beh[i] == i + 1:
            r += 1
        else:
            break
    return torch.tensor(r)

for epoch in range(EPOCHS):
    # Initial state is [0, ..., 0]
    rb = [0.] * SEQ_LEN
    loss = 0.
    rewards = []
    l_probs = []
    for c in range(SEQ_LEN):
        # Process the state
        out = model(torch.tensor(rb))
        # Sample a token
        cd = Categorical(logits=out)
        r = cd.sample()
        l_probs.append(cd.log_prob(r))
        # Update the state representation
        rb[c] = r.item()
        with torch.no_grad():
            # Perform rollouts from the current state
            rew = 0.
            for _ in range(ROLLOUT_NUM):
                i_rb = list(rb)
                # Produce tokens to complete the sequence
                for u in range(c + 1, SEQ_LEN):
                    i_out = model(torch.tensor(i_rb))
                    i_cd = Categorical(logits=i_out)
                    i_rb[u] = i_cd.sample().item()
                # Reward for a complete rollout
                rew += get_rw(i_rb)
            rewards.append(rew / ROLLOUT_NUM)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + EPS)
    # Minimize the negative of average reward
    loss = -(torch.stack(l_probs) * rewards).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Produce a sequence to check
    with torch.no_grad():
        bh = [0.] * SEQ_LEN
        for i in range(SEQ_LEN):
            out = model(torch.tensor(bh))
            cd = Categorical(logits=out)
            r = cd.sample()
            bh[i] = r.item()
        print(bh)
```
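For reference, the inner rollout loop estimates the expected final reward of the current prefix. Here is a standalone sketch of the same estimate, assuming a uniform random completion policy in place of `model`; `rollout_value` and `exact_value` are hypothetical names. It also shows that once a wrong token has been placed, every rollout from that prefix returns the same reward, because `get_rw` stops at the first mismatch.

```python
import random

SEQ_LEN = 10
N_TOKENS = 11  # tokens 0..10, matching the model's output size in the post

def get_rw(seq):
    # Same rule as the post: count leading tokens equal to their 1-based position
    r = 0
    for i in range(SEQ_LEN):
        if seq[i] != i + 1:
            break
        r += 1
    return r

def rollout_value(prefix, n_rollouts, rng):
    """Monte Carlo estimate of the final reward given `prefix`, completing the
    sequence with a uniform random policy (an assumption standing in for `model`)."""
    total = 0
    for _ in range(n_rollouts):
        seq = list(prefix) + [0] * (SEQ_LEN - len(prefix))
        for u in range(len(prefix), SEQ_LEN):
            seq[u] = rng.randrange(N_TOKENS)
        total += get_rw(seq)
    return total / n_rollouts

def exact_value(k):
    """Exact expected reward under the uniform policy when the first k tokens are correct."""
    return k + sum((1 / N_TOKENS) ** j for j in range(1, SEQ_LEN - k + 1))

rng = random.Random(0)
print(rollout_value([2, 8, 3], 10, rng))  # 0.0: the first token is wrong, so every rollout scores 0
print(rollout_value([1, 2, 3], 20000, rng), exact_value(3))  # Monte Carlo estimate vs exact value
```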
The problem is that over time the generated sequence becomes constant, e.g., [6, 6, 6, 6, 6, 6, 6, 6, 6, 6].
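One part of the loop that is easy to check in isolation is the per-episode standardization, `rewards = (rewards - rewards.mean()) / (rewards.std() + EPS)`. The sketch below reproduces it in plain Python on made-up per-step estimates (not values from a training run); `statistics.stdev` uses the n-1 denominator, like `torch.std` by default. It shows two things: an episode whose estimates are all 0 (e.g. one whose first token is wrong) gets all-zero weights and so contributes no gradient, and in an episode whose estimates rise step by step, the early steps receive negative weights even when those tokens were correct.

```python
import statistics

EPS = 2.0 ** -23  # float32 machine epsilon, the value np.finfo(np.float32).eps returns

def standardize(rewards):
    """Plain-Python version of the post's
    (rewards - rewards.mean()) / (rewards.std() + EPS).
    statistics.stdev uses the n-1 denominator, like torch.std by default."""
    mean = statistics.mean(rewards)
    std = statistics.stdev(rewards)
    return [(r - mean) / (std + EPS) for r in rewards]

# Made-up per-step estimates for a perfect episode: the estimate grows with the correct prefix
rising = [float(k) for k in range(1, 11)]
print([round(w, 2) for w in standardize(rising)])  # first five weights are negative

# Episode whose first token is wrong: every rollout scores 0
print(standardize([0.0] * 10))  # all 0.0
```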
I tried different learning rates, network sizes, etc., but the behavior doesn’t change. I guess I’m missing some basic knowledge about the technique I’m trying to apply. Any help will be greatly appreciated.
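On the network-size point: the model has no activation between its two `nn.Linear` layers, so whatever `HIDDEN_SIZE` is, the logits are an affine function of the raw state. The plain-Python sketch below checks this with random weights of the same shapes (10 → 50 → 11); the helper names are hypothetical. The two-layer map equals a single layer with weight `W2 @ W1` and bias `W2 @ b1 + b2`.

```python
import random

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def matmul(A, B):
    # (m x h) @ (h x n) -> (m x n)
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n_in, hidden, n_out = 10, 50, 11  # SEQ_LEN, HIDDEN_SIZE, N_TOKENS from the post
W1, b1 = rand_matrix(hidden, n_in, rng), [rng.uniform(-1, 1) for _ in range(hidden)]
W2, b2 = rand_matrix(n_out, hidden, rng), [rng.uniform(-1, 1) for _ in range(n_out)]

def two_layers(x):
    # Linear -> Linear with no activation in between, as in the post's model
    return add(matvec(W2, add(matvec(W1, x), b1)), b2)

# The equivalent single affine layer
W = matmul(W2, W1)
b = add(matvec(W2, b1), b2)

def one_layer(x):
    return add(matvec(W, x), b)

x = [2.0, 8.0, 3.0] + [0.0] * 7  # the example state from the post
print(max(abs(a - c) for a, c in zip(two_layers(x), one_layer(x))))  # tiny floating-point difference
```

A nonlinearity such as `nn.ReLU()` between the layers is the usual way to make the hidden width matter.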
Thanks!