Simple policy gradient application - not learning the target sequence

Hello,

I’m trying to learn how to implement a simple RL approach on a toy example. The goal is to generate the sequence of tokens (actions) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. A state is the sequence of tokens generated so far, with 0 as padding, e.g. [2, 8, 3, 0, 0, 0, 0, 0, 0, 0] after the first 3 steps. I assume I can only get a reward for a complete sequence.
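
To make this concrete, here is a small, self-contained illustration of the state encoding and the reward I have in mind (just an example; toy_reward is only a name for this sketch, the actual reward function is get_rw in the code below):

# Target sequence and the zero-padded state after the first 3 steps
target = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
state_after_3_steps = [2, 8, 3, 0, 0, 0, 0, 0, 0, 0]

# Reward for a complete sequence: +1 per token matching the target,
# counted from the left and stopping at the first mismatch
def toy_reward(seq):
    r = 0
    for got, want in zip(seq, target):
        if got != want:
            break
        r += 1
    return r

print(toy_reward([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))  # 10 (maximum)
print(toy_reward([1, 2, 3, 7, 5, 6, 7, 8, 9, 10]))  # 3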

The code is as follows:

import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical

SEQ_LEN = 10
EPOCHS = 10000
HIDDEN_SIZE = 50
ROLLOUT_NUM = 10
N_TOKENS = 11
LR = 1e-3
EPS = np.finfo(np.float32).eps.item()

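# Policy network: maps the SEQ_LEN-dim state vector to log-probabilities over the N_TOKENS possible tokens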
model = nn.Sequential(nn.Linear(SEQ_LEN, HIDDEN_SIZE), nn.Linear(HIDDEN_SIZE, N_TOKENS), nn.LogSoftmax(-1))
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Compute the reward: +1 for each token in the correct position, stopping at the first mismatch.
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> SEQ_LEN (the maximum)
def get_rw(beh):
    r = 0.
    for i in range(SEQ_LEN):
        if beh[i] == i + 1:
            r += 1
        else:
            break
    return torch.tensor(r)

for epoch in range(EPOCHS):
    # Initial state is [0, ..., 0]
    rb = [0.] * SEQ_LEN
    loss = 0.
    rewards = []
    l_probs = []
    for c in range(SEQ_LEN):
        # Process the state
        out = model(torch.tensor(rb))
        # Sample a token
        cd = Categorical(logits=out)
        r = cd.sample()
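        # Keep the log-probability of the sampled token for the policy-gradient update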
        l_probs.append(cd.log_prob(r))
        # Update the state representation
        rb[c] = r.item()
        with torch.no_grad():
            # Perform rollouts from the current state
            rew = 0.
            for _ in range(ROLLOUT_NUM):
                i_rb = list(rb)
                # Produce tokens to complete the sequence
                for u in range(c + 1, SEQ_LEN):
                    i_out = model(torch.tensor(i_rb))
                    i_cd = Categorical(logits=i_out)
                    i_rb[u] = i_cd.sample().item()
                # Reward for a complete rollout
                rew += get_rw(i_rb)
        rewards.append(rew / ROLLOUT_NUM)
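    # Normalize the estimated returns (zero mean, unit variance) before the update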
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + EPS)
    # REINFORCE loss: negative sum of log-probabilities weighted by the normalized returns
    loss = -(torch.stack(l_probs) * rewards).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Produce a sequence to check
    with torch.no_grad():
        bh = [0.] * SEQ_LEN
        for i in range(SEQ_LEN):
            out = model(torch.tensor(bh))
            cd = Categorical(logits=out)
            r = cd.sample()
            bh[i] = r.item()
        print(bh)

The problem is that over time the produced sequence becomes constant, e.g. [6, 6, 6, 6, 6, 6, 6, 6, 6, 6].
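
In case it helps with diagnosing this, here is a quick sketch of how I would inspect the policy’s distribution for the first step, to see whether it has collapsed onto a single token (not part of the script above; it assumes the script has already run, so that model and SEQ_LEN are defined):

with torch.no_grad():
    # Log-probabilities for the initial (all-zero) state, converted back to probabilities
    first_step_probs = model(torch.tensor([0.] * SEQ_LEN)).exp()
    print(first_step_probs)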

I tried different learning rates, network sizes, etc., but the behavior doesn’t change. I guess I’m just missing some basic knowledge about the technique I’m trying to apply. Any help will be greatly appreciated.

Thanks!