Can someone help me fix my simple RL model?

This is an attempt at getting a reinforcement learning model to learn tic-tac-toe (a nice target because the game is solvable, so it's clear when the project has succeeded). After actually getting it to run and handling all the tensors properly, it fails to make any progress. To avoid dealing with illegal moves, I coded it so that you lose the game if you choose an already occupied field. I know I could also mask the probabilities so that only legal moves are reinforced (see the sketch right after the list below), but it should work this way too. However, over 20,000 games, every single game ended after 2 moves because of an illegal move, which suggests the model is actually reinforcing that behavior very quickly, and I don't know why. I tried simply reversing the reward, but that did not work either. Other design decisions:

  • The input layer has 10 nodes: the 9 board cells plus a last node that is 1 if it's player 1's turn and -1 if it's player 2's.
  • I randomly pick one player per episode and save only that player's moves to apply the reward to. I wasn't sure how to elegantly handle the fact that the model plays against itself rather than against an environment, so I decided that each episode only learns from one side.
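
For reference, the masking I mentioned above would look roughly like this. This is just a sketch with a made-up helper (masked_action); it assumes the same 3x3 board tensor and 9-element output as my code below, and it is not something my code actually uses:

import torch

def masked_action(action_probs, board):
    # Keep only the model outputs for empty cells, renormalize, then pick a move.
    # Assumes at least one cell is still empty.
    legal = (board.view(9) == 0).float()   # 1.0 where a cell is still free
    masked = action_probs * legal          # occupied cells get probability 0
    masked = masked / masked.sum()         # renormalize over the legal moves
    return torch.argmax(masked).item()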

Here’s the code:

import random
import torch


class TicTacToe:
    def __init__(self):
        self.board = torch.zeros((3, 3))  # Initialize the game board as a PyTorch tensor
        self.player = 1  # Player 1 goes first        
        self.winner = None
    
    def step(self, action):
        done = self.make_move(action)  # Make the move
        if done:
            result = 'illegal move'
            return done, (self.player*-1), result
        done = self.check_result()
        self.player = -self.player  # Switch players
        if done:
            result = 'win player 1' if self.winner == self.player else 'draw' if self.winner == None else 'win player 2'
        else:
            result = 0
        return done, self.winner, result 
    
    def reset(self):
        self.board = torch.zeros((3, 3))
        self.player = 1
        self.winner = None  # Reset the winner for the next game

    def make_move(self, action):
        row, col = action // 3, action % 3
        if self.board[row][col] == 0: 
            self.board[row][col] = self.player
            return False
        else:
            self.winner = self.player
            return True
    
    def check_result(self):
        # Check rows and columns
        for i in range(3):
            if torch.all(self.board[i] == self.player) or torch.all(self.board[:, i] == self.player):
                self.winner = self.player
                return True
        # Check diagonals
        if torch.all(torch.diag(self.board) == self.player) or torch.all(torch.diag(torch.fliplr(self.board)) == self.player):
            self.winner = self.player
            return True
        # Check for a draw (board full, no winner)
        self.winner = None
        return not torch.any(self.board == 0)


# Define the reinforcement learning model

class TicTacToeModel(torch.nn.Module):
    def __init__(self):
        super(TicTacToeModel, self).__init__()
        self.fc1 = torch.nn.Linear(10, 64)  # 10 input features (board positions + active player) -> 64 hidden units
        self.fc2 = torch.nn.Linear(64, 32)  # 64 hidden units -> 32 hidden units
        self.fc3 = torch.nn.Linear(32, 9)  # 32 hidden units -> 9 output features (probabilities of each move)
    
    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Use sigmoid instead of softmax to output probabilities
        return x

# Create the game environment and model
env = TicTacToe()
model = TicTacToeModel()

# Define the loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()  # Cross-entropy loss can be used as a policy gradient loss (defined here, but not actually used below)
optimizer = torch.optim.Adam(model.parameters())

num_episodes = 20000

# Start training loop
for episode in range(num_episodes):
    env.reset()
    done = False
    log_probs = []
    
    # set player to 1 or -1
    active_player = random.choice([-1,1])
    p = 0  # move counter, used as a safety cap below
    while not done:
        player_tensor = torch.tensor([env.player], dtype=torch.float32)
        state = torch.cat((env.board.view(9), player_tensor), dim=0)
        #state = torch.cat((env.board.view(1,9), player_tensor), dim=1)
        
        # Get action probabilities from the model
        action_probs = model(state)
        
        # Pick the action with the highest output (greedy, no sampling)
        action = torch.argmax(action_probs).item()
        if active_player == env.player:
            log_probs.append(torch.log(action_probs[action]))
        
        # Step the environment
        done, winner, result = env.step(action)
        #print(env.board)
        p+=1
        if p == 10:
            break
    # Compute the reward for the episode
    if result == 'illegal move':
        reward = -1
    else:
        reward = 1 if winner == active_player else 0 if winner == None else -1

    # Compute the loss
    loss = -torch.sum(torch.stack(log_probs)) * reward   
    
    # Update the model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()