How to use torch.save and torch.load in OOP for RL?

Hello! I’m a beginner in OOP and RL and I need some advice for my connect 4 game :slight_smile:

First of all, don’t hesitate if you see anything shocking in my code ahah
But most of all I’m wondering where to save my training data record and where to load it so that it’s effective? I’m having a bit of trouble.

Thank you in advance!

import numpy as np
from colorama import Fore, Style
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import torch.nn.functional as F
import pickle
import os

repo = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(repo, "modeleIA.pth")

import matplotlib.pyplot as plt

if torch.cuda.is_available():
    device = torch.device('cuda') 
else:
    device = torch.device('cpu')

class Plateau:
    """ Classe représentant le plateau de jeu """

    def __init__(self, rows, columns):
        """ Constructeur de la classe, initialise le plateau et ses dimensions """
        self.rows = rows
        self.columns = columns
        self.plato = np.zeros((rows, columns))

    def colonne_check(self, col):
        """ Vérifie si la colonne est pleine """
        for i in range(self.rows):
            if self.plato[i][col] == 0:
                return True
        return False


    def placement_jeton(self, col, joueur):
        """ Place le jeton du joueur dans la colonne qu'il sélectionne """
        for i in np.r_[:self.rows][::-1]:
            if self.plato[i][col] == 0:
                self.plato[i][col] = joueur
                return True
        return False

    def affichage(self, joueur):
        """ Affiche le plateau de jeu à l'instant t """
        couleur = Fore.RED if joueur == 1 else Fore.YELLOW
        print(couleur + str(self.plato) + Style.RESET_ALL)

    def check_victoire(self):
        """ Vérifie si un joueur a gagné """
        rows, columns = self.rows, self.columns

        # vérification en ligne
        for r in np.r_[:rows]:
            for d in np.r_[:columns-3]:
                f = d + 4
                s = np.prod(self.plato[r, d:f])
                if s == 1 or s == 16:
                    return True

        # vérification en colonne
        for c in np.r_[:columns]:
            for d in np.r_[:rows-3]:
                f = d + 4
                s = np.prod(self.plato[d:f, c])
                if s == 1 or s == 16:
                    return True

        # vérification en diagonale (bas gauche vers haut droite)
        for r in np.r_[:rows-3]:
            for c in np.r_[:columns-3]:
                f = c + 4
                s = np.prod([self.plato[r+i, c+i] for i in range(4)])
                if s == 1 or s == 16:
                    return True

        # vérification en diagonale (haut gauche vers bas droite)
        for r in np.r_[3:rows]:
            for c in np.r_[:columns-3]:
                f = c + 4
                s = np.prod([self.plato[r-i, c+i] for i in range(4)])
                if s == 1 or s == 16:
                    return True
        return False

    def get_etat(self):
        """ Obtient l'état actuel du plateau sous forme de tableau 1D """
        return self.plato.flatten()

    def get_actions(self):
        """ Obtient les actions possibles à partir de l'état actuel du plateau """
        return np.where(self.plato[0] == 0)[0]

class Joueur:
    """ Classe représentant un joueur humain """

    def __init__(self, numero, max_choix):
        """ Initialise le joueur et son numéro et le nombre de choix possibles """
        self.numero = numero
        self.max_choix = max_choix

    def jouer(self, state, actions):
        """ Demande au joueur de choisir une colonne """
        while True:
            try:
                choix = int(input(f'Joueur {self.numero}, à vous de jouer (entre 1 et {self.max_choix}): ')) - 1
                if choix in actions:
                    return choix
                else:
                    print("Choix invalide. Essayez à nouveau.")
            except ValueError:
                print("Ce n'est pas un nombre. Essayez encore.")

class DQNAgent:
    # Dans le RL, l'agent DQN utilise une mémoire appelée "replay memory"
    # pour stocker les XP passées (état/action/rec/prochain état etc)
    # afin de les réutiliser lors de l'apprentissage

    # Initialisation de l'agent DQN
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=10000) #pareil que sur le github
        self.gamma = 0.95  # facteur d'actualisation, équilibre recomp im et future
        self.epsilon = 1.0  # taux d'exploration initial
        self.epsilon_min = 0.01  # taux d'exploration minimum
        self.epsilon_decay = 0.995  # taux de décroissance de l'exploration
        self.lr = 0.001  # taux d'apprentissage
        self.model = self._build_model()  # Construire le modèle de réseau neuronal
        self.batch_size = 64
        self.update_every = 5
       
        # réseau principal (d'évaluation) pour choisir les actions et réseau cible pour générer les Q-values cibles
        self.dqn_network = self._build_model().to(device)
        self.target_network = self._build_model().to(device)
        #pour l'instant même dim d'entrée, cachée (64) et de sortie

        self.optimizer = optim.Adam(self.dqn_network.parameters(), lr=self.lr)  # Optimiseur courrament utilisé en RL
        self.t_step = 0  # Compteur pour la mise à jour du réseau cible

    def charger_modele(self, chemin):
            self.modele.load_state_dict(torch.load(chemin))
            self.modele.eval()

    def _build_model(self):
            '''modele de réseau de neurones pour l'apprentissage'''
            model = nn.Sequential(
                nn.Linear(self.state_size, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, self.action_size)
            )
            return model

    # L'agent choisit l'action selon l'état et la politique epsilon-greedy
    def act(self, state, eps=0.1):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)  # On convertit l'état en tenseur
        self.dqn_network.eval()  # mode évaluation càd pas de mise à jour des poids

        with torch.no_grad():
            action_values = self.dqn_network(state)  # Calculer la Q-valeur pour chaque action
        self.dqn_network.train()  # Repasser le réseau en mode entrainement

        # politique epsilon-greedy (exploration/exploitation), à améliorer car
        #pour l'instant génère un nombre aléatoire
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())  # Choix de l'action avec la plus grande Q-valeur (greedy)
        else:
            return random.choice(np.arange(self.action_size))  # Choix d'une action aléatoire

    # Stock l'expérience dans la mémoire de remise en état
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # Prend une action et apprendre à partir de l'expérience
    def step(self, state, action, reward, next_state, done):
        self.remember(state, action, reward, next_state, done)  # Stocker l'expérience
        self.epsilon *= self.epsilon_decay #on multiplie par le facteur de décroissance
        self.learn()  # Apprendre de l'expérience

    # Apprentissage à partir de l'expérience (implémentation de l'équation de Belman)
    def learn(self):
        # mémoire de remise en état soit assez grande
        if len(self.memory) < self.batch_size:
            return

        # on échantillonne un batch d'expériences de taille aléatoire
        experiences = random.sample(self.memory, self.batch_size)
        # dézip les élements de l'échantillon 
        states, actions, rewards, next_states, dones = zip(*experiences)

        # Convertion des expériences numpy --> tenseur pytorch
        states = torch.from_numpy(np.vstack(states)).float().to(device)
        actions = torch.from_numpy(np.vstack(actions)).long().to(device)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
        next_states = torch.from_numpy(np.vstack(next_states)).float().to(device)
        dones = torch.from_numpy(np.vstack(dones).astype(np.int64)).float().to(device)

        # Calcul les Q-valeurs cibles et attendues

        #on obtient les qvaleur prédites à partir du modèle cible
        Q_cible_next = self.target_network(next_states).detach().max(1)[0].unsqueeze(1)  # Q-valeurs cibles pour les prochains états
        
        #on calcule les Q cibles pour les états actuels
        Q_cible = rewards + (self.gamma * Q_cible_next * (1 - dones))  # Q-valeurs cibles pour les états actuels
        
        #on calcul les q attendus à partir du modèle
        Q_expected = self.dqn_network(states).gather(1, actions)  # Q-valeurs attendues pour les états actuels
        
        # Calcul de la perte et rétropropagation de l'erreur
        loss = F.mse_loss(Q_expected, Q_cible)  # Calcul la perte w/ MSE
        # On minimise la fonction de perte
        self.optimizer.zero_grad()  # Réinitialisation gradients
        loss.backward()  # Rétropropagation de l'erreur
        self.optimizer.step()  # Mise à jour des poids du réseau

        # Mise à jour du réseau cible on le copie du réseau DQN
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            self.target_network.load_state_dict(self.dqn_network.state_dict())

    def sauvegarder_modele(self, path):
        # Sauvegarde du modèle PyTorch avec pickle
        state = {
                'dqn_network_state_dict': self.dqn_network.state_dict(),
                'target_network_state_dict': self.target_network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'memory': self.memory,
                'epsilon': self.epsilon,
            }

        #with open(chemin, "wb") as fichier:
           #pickle.dump(self.model, fichier)

    def charger_modele(self, path):
        # Chergament du modèle Pytorch avec pickle

        #with open(chemin, "rb") as fichier:
        #    self.model = pickle.load(fichier)
        if os.path.exists(path):
            state = torch.load(path)
            self.dqn_network.load_state_dict(state['dqn_network_state_dict'])
            self.target_network.load_state_dict(state['target_network_state_dict'])
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.memory = state['memory']
            self.epsilon = state['epsilon']
            print(f"Modèle chargé depuis {path}")
        else:
            print(f"Aucun modèle trouvé à {path}")

class IA:
    def __init__(self, numero, max_choix, state_size, agent=None):
        self.numero = numero
        self.max_choix = max_choix
        if agent is None:
            self.agent= DQNAgent(state_size, max_choix)
        else:
            self.agent= agent

    def jouer(self, state, actions):
        """ Choisit une colonne en utilisant l'agent IA """
        proba_victoires = self.calculer_proba_victoires(state, actions)
        # Trie en fonction des probabilités de victoire
        tri_indice_action = np.argsort(proba_victoires)
        # actions de la plus probable à la moins probable
        for indice in reversed(tri_indice_action):
            action = actions[indice]
            # Si l'action est possible, on la retourne
            if action in actions:
                return action
        #(ne devrait pas arriver), retourne une action aléatoire
        return np.random.choice(actions)


    def calculer_proba_victoires(self, state, actions):
        proba_victoires = np.zeros(len(actions))
        for i, action in enumerate(actions):
            next_state = state.copy()
            next_state[action] = self.numero
            proba_victoires[i] = self.agent.act(next_state)
        return proba_victoires

    def apprendre(self, state, action, reward, next_state, done):
        self.agent.step(state, action, reward, next_state, done)

class Jeu:
    """ Classe représentant le jeu en lui-même """

    def __init__(self, rows, columns, joueurs):
        """ Initialise le jeu et les joueurs """
        self.plato = Plateau(rows, columns)
        self.joueurs = joueurs

    def play(self):
            """ Lance le jeu et vérifie si un joueur a gagné ou si la partie est nulle """
            state = self.plato.get_etat()
            while True:
                for joueur in self.joueurs:
                    print(f"Joueur {joueur.numero}")
                    actions = self.plato.get_actions()
                    choix = joueur.jouer(state, actions)
                    self.plato.placement_jeton(choix, joueur.numero)
                    self.plato.affichage(joueur.numero)

                    if self.plato.check_victoire():
                        print(f"Joueur {joueur.numero} a gagné!")
                        self.plato.affichage(joueur.numero)  # Afficher le plateau final
                        return joueur.numero
                    elif np.all(self.plato.plato != 0):
                        print("Match nul!")
                        return None

                    state = self.plato.get_etat()

class Entrainement:
    
    def __init__(self, lignes, colonnes, episodes):
        self.lignes = lignes
        self.colonnes = colonnes
        self.episodes = episodes
        self.victoires = {1: 0, 2: 0, 'Nulles': 0}

        ### Liste vides pour enregistrer les données

        self.episodes_list = []
        self.victoires_joueur1 = []
        self.victoires_joueur2 = []
        self.parties_nulles = []

    def commencer(self):
        print("Choisissez le mode :")
        print("1. Jouer contre l'IA 1")
        print("2. Jouer contre l'IA 2")
        print("3. IA 1 vs IA 2")
        print("4. Jouer humain contre humain")
        print("5. Entraîner les deux IA entre elles")
        choix = int(input("Votre choix : "))

        if choix == 1:
            joueur_humain = Joueur(1, colonnes)
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(2, colonnes, colonnes * lignes, agent_IA1)
            joueurs = [joueur_humain, joueur_IA1]
        elif choix == 2:
            joueur_humain = Joueur(1, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_humain, joueur_IA2]
        elif choix == 3:
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_IA1, joueur_IA2]
        elif choix == 4:
            joueur_humain1 = Joueur(1, colonnes)
            joueur_humain2 = Joueur(2, colonnes)
            joueurs = [joueur_humain1, joueur_humain2]
        elif choix == 5:
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_IA1, joueur_IA2]
            self.entrainement_IA(joueurs)
            return
        else:
            print("Mode invalide. Veuillez choisir 1, 2, 3, 4 ou 5.")
            return

        for i in range(self.episodes):
            print(f"Épisode {i+1}/{self.episodes}")
            jeu = Jeu(lignes, colonnes, joueurs)
            vainqueur = jeu.play()
            if vainqueur is not None:
                self.victoires[vainqueur] += 1
            else:
                self.victoires['Nulles'] += 1
            print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
            print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
            print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")

    def entrainement_IA(self, joueurs):
    
        for i in range(self.episodes):
            print(f"Épisode {i+1}/{self.episodes}")
            jeu = Jeu(lignes, colonnes, joueurs)
            vainqueur = jeu.play()
            if vainqueur is not None:
                self.victoires[vainqueur] += 1
            else:
                self.victoires['Nulles'] += 1

            # Sauvefarde des données
            self.episodes_list.append(i + 1)
            self.victoires_joueur1.append(self.victoires[1] / (i + 1))
            self.victoires_joueur2.append(self.victoires[2] / (i + 1))
            self.parties_nulles.append(self.victoires['Nulles'] / (i + 1))

            print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
            print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
            print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")

    # graphique inutile à changer + tard
        plt.plot(self.episodes_list, self.victoires_joueur1, label='Taux de victoire Joueur 1')
        plt.plot(self.episodes_list, self.victoires_joueur2, label='Taux de victoire Joueur 2')
        plt.plot(self.episodes_list, self.parties_nulles, label='Parties nulles')
        plt.xlabel('Épisodes')
        plt.ylabel('Taux de victoire')
        plt.legend()
        plt.savefig('graphique_evolution.png')
        plt.show()

#Initialise et démarre le jeu
if __name__ == "__main__":
    lignes, colonnes = 6, 7  # dimensions standard du puissance 4
    episodes = 100  # nombre d'épisodes pour l'entraînement
    entrainement = Entrainement(lignes, colonnes, episodes)
    entrainement.commencer()