Contribution: Stateless TicTacToe

Contributing a stateless TicTacToe implementation in TorchRL.

Comes with a play script: to play, just click a square in the matplotlib window with the mouse.
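
For a quick smoke test without the GUI, a random self-play loop like the one below also works. This is a minimal sketch: the mask-based multinomial sampling is illustrative and assumes the implementation that follows is in scope.

import torch
from torchrl.envs.utils import step_mdp

env = TicTacToe(batch_size=1)
td = env.reset()
for _ in range(9):  # a game lasts at most 9 moves
    legal = td['action_mask'].float()
    td['action'] = torch.multinomial(legal, 1)  # sample a random legal move
    td = step_mdp(env.step(td), exclude_reward=False)
    if td['terminated'].all():
        break
print(td['reward'])  # +1 white wins, -1 black wins, 0 draw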

import torch
from torchrl.envs import EnvBase, ObservationTransform, TransformedEnv
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict
from enum import IntEnum
from typing import Optional, Iterable
from torchrl.data import CompositeSpec, BoundedTensorSpec, \
    DiscreteTensorSpec, \
    UnboundedContinuousTensorSpec


class Player(IntEnum):
    white = 1
    black = -1


def gen_params(batch_size):
    # initial state: empty board, every move legal; the player starts as black
    # because _step flips the player before placing, so white moves first
    td = TensorDict({
        'state': torch.zeros(3, 3),
        'player': torch.tensor([Player.black]),
        'action_mask': torch.tensor([True] * 9)
    }, batch_size=[])
    return td.expand(batch_size).contiguous()


def _step(td):
    # flip the turn first: the incoming action belongs to the player about to move,
    # so td['player'] ends up holding the player who just played
    td['player'] = -td['player']
    N = td.shape[0]
    batch_range = torch.arange(N)
    action = td['action'].squeeze(-1)  # (N,) cell indices; legality is enforced by the caller via action_mask
    r, c = action // 3, action % 3
    td['state'][batch_range, r, c] = td['player'].squeeze(-1).to(dtype=td['state'].dtype)
    action_mask = (td['state'] == 0).reshape(N, 9)
    # a completed line sums to +3 (white) or -3 (black)
    rows_win = td['state'].sum(dim=2).abs() == 3
    columns_win = td['state'].sum(dim=1).abs() == 3
    diags_win = torch.diagonal(td['state'], dim1=1, dim2=2).sum(dim=-1, keepdim=True).abs() == 3
    rev_diags_win = torch.diagonal(td['state'].flip(-1), dim1=1, dim2=2).sum(dim=-1, keepdim=True).abs() == 3
    full = (td['state'].abs().sum(dim=(-1, -2)) == 9).unsqueeze(-1)
    win = torch.any(rows_win | columns_win | diags_win | rev_diags_win, dim=-1, keepdim=True)
    terminated = win | full
    # the reward is the winner's sign (+1 white, -1 black), 0 for a draw or an ongoing game
    reward = torch.where(condition=win, input=td['player'].to(td['state'].dtype), other=torch.zeros(N, 1))

    out = TensorDict({
        'state': td['state'],
        'player': td['player'],
        'action_mask': action_mask,
        'terminated': terminated,
        'reward': reward
    }, batch_size=td.shape)
    return out


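# _make_spec declares what TorchRL sees: board/player/mask observations,
# a single categorical action in [0, 9), and a scalar reward per environment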
def _make_spec(self, td_params):
    f_type = td_params['state'].dtype
    state = BoundedTensorSpec(
        minimum=-1,
        maximum=1,
        shape=torch.Size(td_params['state'].shape),
        dtype=f_type,
    )
    player = BoundedTensorSpec(minimum=-1, maximum=1, shape=(*td_params.shape, 1), dtype=torch.int64)
    action_mask = BoundedTensorSpec(minimum=0, maximum=1, shape=(*td_params.shape, 9), dtype=torch.bool)

    self.observation_spec = CompositeSpec(state=state, player=player, action_mask=action_mask, shape=td_params.shape)
    self.action_spec = DiscreteTensorSpec(9, shape=(*td_params.shape, 1))
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1), dtype=f_type)


def _set_seed(self, seed: Optional[int]):
    # the game itself is deterministic; the generator is kept only to satisfy the EnvBase API
    rng = torch.manual_seed(seed)
    self.rng = rng


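# TorchRL passes an optional boolean '_reset' mask to request partial resets:
# only the masked entries of the batch are re-initialised, the rest keep playing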
def _reset(self, tensordict=None):
    batch_size = tensordict.shape if tensordict is not None else self.batch_size
    if tensordict is None or tensordict.is_empty():
        tensordict = self.gen_params(batch_size).to(self.device)
    if '_reset' in tensordict.keys():
        reset_state = self.gen_params(batch_size).to(self.device)
        reset_mask = tensordict['_reset'].squeeze(-1)
        for key in reset_state.keys():
            tensordict[key][reset_mask] = reset_state[key][reset_mask]
    return tensordict


class TicTacToe(EnvBase):
    batch_locked = False
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed

    def __init__(self, td_params=None, device="cpu", batch_size=None):

        if batch_size is None:
            batch_size = torch.Size([1])
        elif isinstance(batch_size, int):
            batch_size = torch.Size([batch_size])
        elif isinstance(batch_size, Iterable):
            # torch.Size is itself an Iterable, so this branch also accepts a Size
            batch_size = torch.Size(batch_size)
        else:
            raise TypeError("batch_size must be an int, an iterable of ints, or a torch.Size")

        if td_params is None:
            td_params = self.gen_params(batch_size)
        super().__init__(device=device, batch_size=batch_size)
        self._make_spec(td_params)

    @staticmethod
    def player_perspective(state, player):
        return state * player

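player_perspective is unused by the play script; presumably it is meant for training, where the board can be flipped so a policy always sees its own pieces as +1. A hypothetical sketch:

# player has shape (N, 1); unsqueeze so the sign broadcasts over the (N, 3, 3) board
canonical = TicTacToe.player_perspective(td['state'], td['player'].unsqueeze(-1))
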

class RGBTransform(ObservationTransform):
    def __init__(self, out_key=None):
        out_keys = ['pixels'] if out_key is None else [out_key]
        super().__init__(in_keys=['state'], out_keys=out_keys)
        self.colors = {
            'white': torch.tensor([255, 255, 255], dtype=torch.uint8),
            'black': torch.tensor([0, 0, 0], dtype=torch.uint8),
            'grey': torch.tensor([128, 128, 128], dtype=torch.uint8)
        }

    def forward(self, tensordict):
        return self._call(tensordict)

    def _reset(self, tensordict, tensordict_reset):
        return self._call(tensordict_reset)

    def _call(self, td):
        # render each board as a 3x3 RGB image: empty squares grey, pieces white/black
        td['pixels'] = torch.zeros(*td.shape, 3, 3, 3, dtype=torch.uint8)
        td['pixels'][td['state'] == 0] = self.colors['grey']
        td['pixels'][td['state'] == 1] = self.colors['white']
        td['pixels'][td['state'] == -1] = self.colors['black']
        return td

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        # the (N, H, W) state becomes a channels-last (N, H, W, 3) uint8 image,
        # matching what _call actually produces
        return BoundedTensorSpec(
            minimum=0,
            maximum=255,
            shape=torch.Size((*observation_spec.shape, 3)),
            dtype=torch.uint8,
            device=observation_spec.device
        )


if __name__ == '__main__':

    env = TicTacToe(batch_size=1)
    env = TransformedEnv(env)
    env.append_transform(RGBTransform())
    state = env.reset()

    import matplotlib.pyplot as plt

    # Set up the board: one plot unit per square, origin at the bottom left
    grid_size = 3

    # Create a figure and axis
    fig, ax = plt.subplots()
    ax.set_xlim(0, grid_size)
    ax.set_ylim(0, grid_size)
    ax.set_aspect('equal')
    screen = ax.imshow(state['pixels'][0], extent=[0, 3, 0, 3], interpolation='nearest', origin='lower')

    # Define the click event handler
    def on_click(event):
        global state
        if event.inaxes == ax:
            action = int(event.ydata) * 3 + int(event.xdata)
            if state['action_mask'][0, action]:
                state['action'] = torch.tensor([[action]])
                state = step_mdp(env.step(state), exclude_reward=False)
                screen.set_data(state['pixels'][0])
                fig.canvas.draw()
                if state['terminated'][0]:
                    reward = state['reward'][0].item()
                    if reward == 0:
                        print('draw')
                    else:
                        print(f'{Player(int(reward)).name} wins')
                    state = env.reset()
            else:
                print("illegal move")

    # Connect the click event to the event handler
    fig.canvas.mpl_connect('button_press_event', on_click)

    plt.show()