I’m trying to learn how to implement a simple RL approach on a toy example. The goal is to generate the sequence of tokens (actions) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. A state is the sequence of tokens generated so far, padded with 0, e.g., [2, 8, 3, 0, 0, 0, 0, 0, 0, 0] after the first 3 steps. I assume I can only get a reward for a complete sequence.
The code is the following:
    import torch
    import torch.nn as nn
    import numpy as np
    from torch.distributions import Categorical

    SEQ_LEN = 10
    EPOCHS = 10000
    HIDDEN_SIZE = 50
    ROLLOUT_NUM = 10
    N_TOKENS = 11
    LR = 1e-3
    EPS = np.finfo(np.float32).eps.item()

    model = nn.Sequential(nn.Linear(SEQ_LEN, HIDDEN_SIZE),
                          nn.Linear(HIDDEN_SIZE, N_TOKENS),
                          nn.LogSoftmax(-1))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Compute reward: +1 for each correct token, counted from the start and
    # stopping at the first mismatch. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> RMAX
    def get_rw(beh):
        r = 0.
        for i in range(SEQ_LEN):
            if beh[i] == i + 1:
                r += 1
            else:
                break
        return torch.tensor(r)

    for epoch in range(EPOCHS):
        # Initial state is [0, ..., 0]
        rb = [0.] * SEQ_LEN
        loss = 0.
        rewards = []
        l_probs = []
        for c in range(SEQ_LEN):
            # Process the state
            out = model(torch.tensor(rb))
            # Sample a token
            cd = Categorical(logits=out)
            r = cd.sample()
            l_probs.append(cd.log_prob(r))
            # Update the state representation
            rb[c] = r.item()
            with torch.no_grad():
                # Perform rollouts from the current state
                rew = 0.
                for _ in range(ROLLOUT_NUM):
                    i_rb = list(rb)
                    # Produce tokens to complete the sequence
                    for u in range(c + 1, SEQ_LEN):
                        i_out = model(torch.tensor(i_rb))
                        i_cd = Categorical(logits=i_out)
                        i_rb[u] = i_cd.sample().item()
                    # Reward for a complete rollout
                    rew += get_rw(i_rb)
                rewards.append(rew / ROLLOUT_NUM)
        rewards = torch.tensor(rewards)
        rewards = (rewards - rewards.mean()) / (rewards.std() + EPS)
        # Minimize the negative of average reward
        loss = -(torch.stack(l_probs) * rewards).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Produce a sequence to check
        with torch.no_grad():
            bh = [0.] * SEQ_LEN
            for i in range(SEQ_LEN):
                out = model(torch.tensor(bh))
                cd = Categorical(logits=out)
                r = cd.sample()
                bh[i] = r.item()
            print(bh)
The problem is that, over time, the produced sequence becomes constant, e.g., [6, 6, 6, 6, 6, 6, 6, 6, 6, 6].
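One way to check that this is a real collapse of the policy rather than sampling noise is to look at the entropy of the token distribution for a given state. The following is only a minimal sketch: `policy_entropy` is a helper name I made up (it is not part of PyTorch), and `demo_model` is an untrained stand-in with the same architecture as my model, so that the snippet runs on its own. In practice the trained `model` would be passed in instead.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

SEQ_LEN, HIDDEN_SIZE, N_TOKENS = 10, 50, 11

def policy_entropy(model, state):
    """Return (entropy, probs) of the token distribution for one state.

    Hypothetical helper for illustration; `state` is a list of SEQ_LEN floats.
    """
    with torch.no_grad():
        cd = Categorical(logits=model(torch.tensor(state)))
    return cd.entropy().item(), cd.probs

# Untrained stand-in with the same architecture as the model in the question
demo_model = nn.Sequential(nn.Linear(SEQ_LEN, HIDDEN_SIZE),
                           nn.Linear(HIDDEN_SIZE, N_TOKENS),
                           nn.LogSoftmax(-1))

state = [0.] * SEQ_LEN  # the initial, all-padding state
ent, probs = policy_entropy(demo_model, state)
print(f"entropy={ent:.3f}, most likely token={probs.argmax().item()}")
```

An entropy close to ln(11) ≈ 2.40 means the distribution is nearly uniform over the 11 tokens, while a value close to 0 for every state would confirm that the policy has collapsed onto a single token regardless of the input.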
I tried different learning rates, network sizes, etc., but the behavior doesn’t change. I guess I’m just missing some basic knowledge about the technique I’m trying to apply. Any help will be greatly appreciated.