I have some problems with algorithm realization

Hi all
I tried to use algorithm from this paper (code from paper) but my realization doesnt work well and I dont know why. Am I missed something?

results of learning:

agent.py:

import random
import numpy as np

import torch
from torch import optim
from torch.nn import Linear
import torch.nn as nn
import torch.autograd as ag

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=7, out_channels=14, kernel_size=5, stride=1, padding=1, bias=True),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=14, out_channels=45, kernel_size=5, stride=1, padding=1, bias=True),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=45, out_channels=60, kernel_size=5, stride=1, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Flatten()
        )

        self.fc = nn.Sequential(
            Linear(720, 180, bias=True),
            nn.LeakyReLU(),

            Linear(180, 180, bias=True),
            nn.LeakyReLU(),

            Linear(180, 8100, bias=True),
        )

        self.loss = nn.SmoothL1Loss()

    def forward(self, x):#, player):
        #print(x.shape)
        #x = self.conv(x)
        x = self.conv(x)
        #print(x.shape)
        x = self.fc(x)

        return x


class Randomizer:
    def get_action(self, legal_actions):
        legals = torch.argwhere(legal_actions==1)
        idx = np.random.choice(len(legals), 1)[0]
        action = legals[idx].tolist()
        return action

class Agent:
    def __init__(self, model_fn, device=None):
        if device==None:
            self.device = torch.device("cuda:0")
        else:
            self.device = device

        self.online_model = model_fn().to(self.device)
        self.optimizer = optim.Adam(self.online_model.parameters(), lr=0.0001)


    def get_random_action(self, legal_actions):

        legals = torch.argwhere(legal_actions==1)
        idx = np.random.choice(len(legals), 1)[0]
        action = legals[idx].tolist()
        return action

    def get_action(self, model, state, legal_actions, epsilon):
        is_explore = random.randint(1,100)/100<epsilon

        if is_explore:
            return self.get_random_action(legal_actions)

        else:
            reshaped_actions = model(state).reshape((90, 90))
            actions = torch.multiply(reshaped_actions, legal_actions)
            action = torch.argmax(actions).item()
            from_ = action//90
            to_ = action%90

            is_legal = legal_actions[from_, to_]==1
            if is_legal:
                return [from_, to_]
            else:
                return self.get_random_action(legal_actions)

    def calculate_q_loss(self, samples):

        states,actions,returns = samples

        self.optimizer.zero_grad()


        qvalues = torch.gather(input=self.online_model(states),dim=1,index=actions.unsqueeze(0).to(torch.int64)).squeeze(0)
        loss = nn.SmoothL1Loss()(qvalues,returns)

        loss.backward()
        self.optimizer.step()

        return loss.item()


buffer.py:


import torch
import numpy as np
from return_calculations import calculate_lambda_returns
from JanggiTools import int2state,ints2states,state2int

class Buffer:
    def __init__(self, gamma, model, max_steps, tau=0.05, max_length=1000000, batch_size=64, cache_size=160000, block_size=400, device=None):
        self.gamma = gamma

        self.model = model

        self.block_size = block_size
        self.cache_size = cache_size
        self.mixed_cache_idxs = np.empty(shape=self.cache_size,dtype=np.ndarray)

        self.update_b = int(cache_size*tau)
        self.update_bn = int(1/tau)
        self.update_bc = 0

        
        self.cached_states = torch.zeros((self.cache_size,7,9,10))
        self.cached_actions = torch.zeros((self.cache_size))
        self.cached_returns = torch.zeros((self.cache_size))

        self.batch_size = batch_size
        self.batch_count = 0

        self.max_length = max_length + block_size


        self.c = 0

        self.saved_states = np.empty(shape=self.max_length, dtype=np.ndarray)
        self.saved_actions = np.empty(shape=self.max_length, dtype=int)
        self.saved_rewards = np.empty(shape=self.max_length, dtype=np.int8)
        self.saved_dones = np.empty(shape=self.max_length, dtype=np.ndarray)
        self.saved_legals = np.empty(shape=self.max_length, dtype=np.ndarray)

        self.next_idx = 0
        self.entries = 0

        self.alpha = 1

        if device==None:
            self.device = torch.device("cuda:0")
        else:
            self.device = device


    def add_sample(self, state, action, reward, finished, legals):
        self.saved_states[self.next_idx] = state.to(torch.int8)
        self.saved_actions[self.next_idx] = action
        self.saved_rewards[self.next_idx] = reward
        self.saved_dones[self.next_idx] = finished
        self.saved_legals[self.next_idx] = legals.to(torch.int8)

        self.entries = min(self.entries+1, self.max_length)

        self.next_idx += 1
        self.next_idx = self.next_idx % self.max_length

    def refresh_cache(self):
        with torch.no_grad():
            self.model.to(torch.device("cpu"))
            #rrrr = 0

            errors = []
            self.cached_states = torch.zeros((self.cache_size,7,9,10))
            self.cached_actions = torch.zeros((self.cache_size))
            self.cached_returns = torch.zeros((self.cache_size))

            n_blocks = self.cache_size//self.block_size

            idxs = np.random.randint(0, self.entries-self.block_size, size=n_blocks)


            for i,idx in enumerate(idxs):
                bs = self.block_size

                states,actions,rewards,dones,legals = self.extract_block(idx)



                qvalues = self.model(states)


                max_qvalues = qvalues.multiply(legals).max(dim=1).values

                qvalues = torch.gather(input=qvalues,dim=1,index=torch.from_numpy(actions).unsqueeze(0).to(torch.int64)).squeeze(0)


                r = self.calculate_returns(
                    rewards=torch.tensor(rewards).float(),
                    qvalues=max_qvalues,
                    dones=torch.tensor(dones).float(),
                    discount=self.gamma
                )

                errors+=(r-qvalues).abs().tolist()

                self.cached_states[i:i+bs] = states[:-1]
                self.cached_actions[i:i+bs] = torch.tensor(actions)
                self.cached_returns[i:i+bs] = r
            errors = np.array(errors)
            priorities = self.get_priorities(errors)
            self.mixed_cache_idxs = np.random.choice(a=len(priorities),size=self.cache_size,replace=True,p=priorities)


    def extract_block(self,idx):
        start = idx
        bs = self.block_size 

        states = torch.stack(
            self.saved_states[start:start+bs+1].tolist()).to(torch.float32)
        actions = self.saved_actions[start:start+bs]
        rewards = self.saved_rewards[start:start+bs].astype(np.float16)
        dones = self.saved_dones[start:start+bs].astype(np.float16)
        legals = torch.stack(self.saved_legals[start:start+bs+1].tolist()).to(torch.float32)

        return states,actions,rewards,dones,legals
    def get_priorities(self,errors):
        priorities = errors/errors.sum()

        return priorities
    def calculate_returns(self,rewards,qvalues,dones,discount,k=21):
        returns = np.empty(
            shape=[k, len(rewards)],
            dtype=np.float32)

        for i in range(0,k):
            returns[i] = calculate_lambda_returns(rewards, qvalues, dones, discount, lambd=i/(k-1))

        return torch.from_numpy(np.median(returns, axis=0))
    def get_samples(self):
        b = self.batch_count * self.batch_size
        if b>=self.cache_size-self.batch_size:
            b = 0
            self.batch_count = 0

        idxs = self.mixed_cache_idxs[b:b+self.batch_size]

        states = self.cached_states[idxs]
        actions = self.cached_actions[idxs]
        returns = self.cached_returns[idxs]


        self.batch_count += 1

        return states.to(self.device), actions.to(self.device), torch.tensor(returns).to(self.device)

learning.py:

import os
import time
import datetime

import janggi

from config import ConfigManager
from agent import Agent, Net, Randomizer
from env import Env
from buffer import Buffer
import torch
import numpy as np

from torch.utils.tensorboard import SummaryWriter
import ptan

name = "1.2.0.9"
writer = SummaryWriter(f"learning-{name}", comment="-" + name, flush_secs=120)


def off():
    os.system("shutdown /s /t 0")


e = 200000
REFRESH_CACHE_TIME = 2500

LEARNING_TIME = 10000

time_ = time.time()
config_manager = ConfigManager("config.json")
print(config_manager.read())
config_ = config_manager.read()

yy_loss = []
yy_len = []
yy_reward = []

device = torch.device("cuda:0")

gamma = 0.9

agent = Agent(lambda: Net(), device=device)

net = Net().to(device)

agent.online_model = net
agent.copy_target()
agent.target_model.eval()
randomizer = Randomizer()

buffer = Buffer(gamma=gamma, model=agent.online_model, device=device, max_steps=e)

def decay_schedule(init_value, min_value, decay_ratio, max_steps, log_start=-2, log_base=10):
    decay_steps = int(max_steps * decay_ratio)
    rem_steps = max_steps - decay_steps
    values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
    values = (values - values.min()) / (values.max() - values.min())
    values = (init_value - min_value) * values + min_value
    values = np.pad(values, (0, rem_steps), 'edge')
    return values

epsilons = decay_schedule(1, 0.05, 0.6, e)
tb_tracker = ptan.common.utils.TBMeanTracker(writer, batch_size=10)
tb_tracker.__enter__()
L = 0
for i in range(e):
    env = Env(device=device)
    opponent = agent.get_opponent()
    online_turn = np.random.choice([1,-1],1)[0]

    legals_ = env.get_legals()
    legals = []
    finished = False
    epsilon = epsilons[i-LEARNING_TIME] if i>LEARNING_TIME else 1
    states = []
    actions = []
    rewards = []
    next_states = []
    finishes = []
    masks = []
    agent.online_model.eval()
    w = 0
    while not finished:
        state = env.get_state()

        if env.game.to_play()==online_turn:
            action = agent.get_action(model=agent.online_model, state=state.unsqueeze(0).to(device), legal_actions=legals_, epsilon=epsilon)
            legals.append(legals_.reshape(8100).cpu())
            states.append(state)

            reward, finished, legals_ = env.step(*action)


            actions.append(action[0]*9+action[1])
            rewards.append(reward)
            finishes.append(finished)

            if reward==1: w=1
        else:
            action = randomizer.get_action(legals_)
            reward_, finished, legals_ = env.step(*action)

            if len(rewards)>0:
                rewards[-1] -= reward_


    for j in range(len(rewards)):
        buffer.add_sample(states[j], actions[j], rewards[j], finishes[j], legals[j])


    if i>=LEARNING_TIME:
        if i==LEARNING_TIME or i%REFRESH_CACHE_TIME==0:
            buffer.refresh_cache()

        agent.online_model.to(device)

        samples = buffer.get_samples()
        agent.online_model.train()
        L = agent.calculate_q_loss(samples)


        tb_tracker.track("data/loss", L, i)
        tb_tracker.track("data/len",env.c,i)
        tb_tracker.track("data/wins",w,i)
        tb_tracker.track("data/reward",sum(rewards),i)

    if i%100==0: print(f"{i}: {L}")

    if config_["save_middle_model"] and i%config_["save_middle_model_steps"]==0:
        state = {'info': "JanggiBotV1",  
                    'date': datetime.datetime.now(),  
                    'epochs': i,
                    'model': agent.online_model.state_dict(),  
                    'optimizer': agent.optimizer.state_dict()}  
        str_dir = f'Models/JanggiBot-1.2.0.1-{i}.pt'
        torch.save(state, str_dir)


state = {'info': "JanggiBotV1",
            'date': datetime.datetime.now(), 
            'epochs': i,
            'model': agent.online_model.state_dict(),
            'optimizer': agent.optimizer.state_dict()}  
torch.save(state, f'Models/JanggiBot-1.2.0.1.pt')

if config_manager.read()["save_time"]:
    f = open("time.txt", 'w')
    print(time.time() - time_)
    f.write(str(time.time() - time_))
    f.close()
if config_manager.read()["turn_off"]:
    off()

tb_tracker.__exit__()

return_calculations.py:

import numpy as np

def calculate_lambda_returns(rewards, qvalues, dones, discount, lambd):
    dones = dones
    qvalues[-1] *= (1.0 - dones[-1])

    #print(rewards,qvalues,dones)
    lambda_returns = rewards + (discount * qvalues[1:])
    for i in reversed(range(len(rewards) - 1)):
        a = lambda_returns[i] + (discount * lambd) * (lambda_returns[i+1] - qvalues[i+1])
        b = rewards[i]
        lambda_returns[i] = (1.0 - dones[i]) * a + dones[i] * b
    return lambda_returns