Model Boilerplate for a Simple DQN

Could someone please provide me with a simple, bare-bones DQN training loop? I have looked at the examples in the TorchRL documentation, but they quickly become overwhelmingly complex, and after a few months spent on them I am not 100% sure that I am doing the right thing. I am interested in solving discrete grid problems with the [Atari model](https://docs.pytorch.org/rl/main/_modules/torchrl/modules/models/models.html#ConvNet.default_atari_dqn) using the most basic approach possible: no replay buffers, no delayed double DQNs or anything of the sort, just the core mechanics (maybe this could be pushed to the tutorials page later?).

I apologize in advance if this is already available somewhere. In case it is, please point me to it.

Here’s a simple DQN with no environment: Rock, Paper, Scissors. It uses your last 5 choices to predict what your next choice will be, in just over 100 lines of code.

import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

n = 5  # history length


class DQN(nn.Module):
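    # Q-network: maps the flattened one-hot history of the player's last n
    # moves (3 * n inputs) to one Q-value per move (rock, paper, scissors).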
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(3 * n, 32), nn.ReLU(), nn.Linear(32, 3))

    def forward(self, x):
        return self.fc(x)


net = DQN()
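# Separate target network used for the bootstrapped targets; it is synced
# with the online network every 10 turns to stabilize training.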
target_net = DQN()
target_net.load_state_dict(net.state_dict())
optimizer = optim.Adam(net.parameters(), lr=0.01)
memory = deque(maxlen=2000)
gamma = 0.99
epsilon = 1.0
batch_size = 32
moves = ['rock', 'paper', 'scissors']


def select_action(state):
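    # Epsilon-greedy: pick a random move with probability epsilon,
    # otherwise the move with the highest predicted Q-value.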
    if random.random() < epsilon:
        return random.randint(0, 2)
    with torch.no_grad():
        return net(state).argmax().item()


def optimize():
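    # One training step: sample a random minibatch from the replay memory
    # and regress Q(s, a) toward the one-step TD target.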
    if len(memory) < batch_size:
        return
    batch = random.sample(memory, batch_size)
    states = torch.cat([s for s, a, r, ns in batch])
    actions = torch.cat([a for s, a, r, ns in batch])
    rewards = torch.cat([r for s, a, r, ns in batch])
    next_states = torch.cat([ns for s, a, r, ns in batch])
    q_values = net(states).gather(1, actions)
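    # TD target: r + gamma * max_a' Q_target(s', a'); detach() keeps
    # gradients from flowing through the target network.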
    next_q = target_net(next_states).max(1)[0].detach()
    expected = rewards + gamma * next_q
    loss = nn.MSELoss()(q_values, expected.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


state = torch.zeros(1, 3 * n)
steps = 0
wins = 0
ties = 0
losses = 0

print("Rock-Paper-Scissors with DQN AI")
print("Commands: r=rock, p=paper, s=scissors, q=quit")

while True:
    # Select AI action based on current state (hidden from user)
    action = select_action(state)

    user_in = input("\nYour move (r/p/s/q): ").lower()
    if user_in == 'q':
        break

    user_map = {'r': 0, 'p': 1, 's': 2}
    if user_in not in user_map:
        print("Invalid input. Use r, p, s, or q.")
        continue

    user = user_map[user_in]

    # Now reveal both choices
    print(f"You chose: {moves[user]}")
    print(f"AI chose: {moves[action]}")

    if action == user:
        reward = 0
        ties += 1
        print("Result: Tie!")
    elif (action - user) % 3 == 1:
        reward = 1
        wins += 1
        print("Result: AI wins!")
    else:
        reward = -1
        losses += 1
        print("Result: You win!")

    one_hot = torch.nn.functional.one_hot(torch.tensor([user]), num_classes=3).float()
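    # Slide the history window: drop the oldest one-hot move and append
    # the player's latest move.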
    next_state = torch.cat([state[:, 3:], one_hot], dim=1)

    memory.append(
        (state, torch.tensor([[action]], dtype=torch.long), torch.tensor([reward], dtype=torch.float32), next_state))

    state = next_state
    optimize()
    steps += 1

    if steps % 10 == 0:
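        # Copy the online network's weights into the target network.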
        target_net.load_state_dict(net.state_dict())

    epsilon = max(0.01, epsilon * 0.99)

    if steps % 10 == 0:
        total = wins + ties + losses
        print(f"\nStats after {steps} turns:")
        print(f"Wins: {wins} ({wins / total * 100:.1f}%)")
        print(f"Ties: {ties} ({ties / total * 100:.1f}%)")
        print(f"Losses: {losses} ({losses / total * 100:.1f}%)")
        print(f"Epsilon: {epsilon:.3f}")

Newbie question: The DQN tutorial is pretty similar, but the PPO tutorial uses TorchRL. Is there a specific reason not to use TorchRL for DQN?

The DQN tutorial was written in 2017. TorchRL didn’t start development until 2021, with its first launch in 2022.