"""A stateless TicTacToe environment implemented with PyTorch / TorchRL.

Includes a play script: run this file and click squares in the Matplotlib
window to make moves.
"""
import torch
from torchrl.envs import EnvBase, ObservationTransform, TransformedEnv
from torchrl.envs.transforms.transforms import _apply_to_composite, step_mdp
from tensordict import TensorDict
from enum import IntEnum
from typing import Optional, Iterable
from torchrl.data import CompositeSpec, BoundedTensorSpec, \
DiscreteTensorSpec, \
UnboundedContinuousTensorSpec
class Player(IntEnum):
    """Stone colours; the int value is the mark written into the 3x3 board tensor (0 = empty square)."""
    white = 1
    black = -1
def gen_params(batch_size):
    """Build the starting tensordict for a fresh batch of games.

    The board is all zeros (empty), every square is a legal move, and
    'player' holds black; `_step` flips the player before placing a stone,
    so white moves first.
    """
    template = TensorDict(
        {
            'state': torch.zeros(3, 3),
            'player': torch.tensor([Player.black]),
            'action_mask': torch.tensor([True] * 9),
        },
        batch_size=[],
    )
    # Expand to the requested batch shape and make it writable (expand
    # alone returns non-contiguous views sharing storage).
    return template.expand(batch_size).contiguous()
def _step(td):
    """Advance every game in the batch by one move.

    Expects 'action' (flat square index in [0, 9), shape (N, 1)), 'state'
    (N, 3, 3) and 'player' (N, 1) in `td`. Returns a new TensorDict with
    the updated board, switched player, refreshed action mask, terminal
    flags, and reward (+1/-1 winner's sign, 0 otherwise).
    """
    N = td.shape[0]
    # The stored player is the one who moved last; flip it so the stone
    # placed below belongs to the side whose turn it is.
    td['player'] = -td['player']
    batch_range = torch.arange(N)
    # Flatten action to (N,) — using (N, 1)-shaped row/col indices against
    # a (N,)-shaped batch index would broadcast to an (N, N) index.
    action = td['action'].reshape(N)
    r, c = action // 3, action % 3
    state = td['state']
    state[batch_range, r, c] = td['player'].reshape(N).to(dtype=state.dtype)
    action_mask = (state == 0).reshape(N, 9)
    # Win checks: any row / column / diagonal summing to +-3.
    rows_win = state.sum(dim=1).abs() == 3                    # (N, 3)
    cols_win = state.sum(dim=2).abs() == 3                    # (N, 3)
    # Reduce only over the diagonal elements (dim=-1), NOT over the whole
    # tensor; keepdim keeps shape (N, 1) so it broadcasts with (N, 3).
    diag_win = torch.diagonal(state, dim1=1, dim2=2).sum(dim=-1, keepdim=True).abs() == 3
    rev_diag_win = torch.diagonal(state.flip(-1), dim1=1, dim2=2).sum(dim=-1, keepdim=True).abs() == 3
    # Draw: all nine squares occupied. unsqueeze keeps shape (N, 1) so the
    # OR with `win` below stays element-wise per game.
    full = (state.abs().sum(dim=(-1, -2)) == 9).unsqueeze(-1)
    win = torch.any(rows_win | cols_win | diag_win | rev_diag_win, dim=-1, keepdim=True)
    terminated = win | full
    # On a win the reward is the winner's sign (the player that just moved).
    reward = torch.where(condition=win, input=td['player'], other=torch.zeros(N, 1))
    out = TensorDict({
        'state': state,
        'player': td['player'],
        'action_mask': action_mask,
        'terminated': terminated,
        'reward': reward,
    }, batch_size=td.shape)
    return out
def _make_spec(self, td_params):
    """Populate observation, action and reward specs from a template tensordict."""
    dtype = td_params['state'].dtype
    batch_shape = td_params.shape
    # Observations: the board itself, whose turn it is, and the legal-move mask.
    self.observation_spec = CompositeSpec(
        state=BoundedTensorSpec(
            minimum=-1,
            maximum=1,
            shape=torch.Size(td_params['state'].shape),
            dtype=dtype,
        ),
        player=BoundedTensorSpec(minimum=-1, maximum=1, shape=(*batch_shape, 1), dtype=torch.int64),
        action_mask=BoundedTensorSpec(minimum=0, maximum=1, shape=(*batch_shape, 9), dtype=torch.bool),
        shape=batch_shape,
    )
    # One discrete action per square of the 3x3 board.
    self.action_spec = DiscreteTensorSpec(9, shape=(*batch_shape, 1))
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_shape, 1), dtype=dtype)
def _set_seed(self, seed: Optional[int]):
    """Seed torch's global RNG and keep the resulting generator on the env."""
    self.rng = torch.manual_seed(seed)
def _reset(self, tensordict=None):
    """Reset the environment, supporting both full and partial (masked) resets.

    When called with no tensordict (or an empty one), a fresh game is
    generated for the whole batch. When the incoming tensordict carries a
    '_reset' mask (TorchRL's partial-reset convention), only the masked
    games are replaced with fresh state.
    """
    # Fall back to the env's own batch size when no tensordict is supplied.
    batch_size = tensordict.shape if tensordict is not None else self.batch_size
    if tensordict is None or tensordict.is_empty():
        # Full reset: brand-new boards for every game in the batch.
        tensordict = self.gen_params(batch_size).to(self.device)
    if '_reset' in tensordict.keys():
        # Partial reset: overwrite only the entries flagged by the mask,
        # leaving in-progress games untouched.
        # NOTE(review): '_reset' is assumed to have a trailing singleton dim
        # (hence the squeeze) — confirm against the collector that calls this.
        reset_state = self.gen_params(batch_size).to(self.device)
        reset_mask = tensordict['_reset'].squeeze(-1)
        for key in reset_state.keys():
            tensordict[key][reset_mask] = reset_state[key][reset_mask]
    return tensordict
class TicTacToe(EnvBase):
    """Stateless, batched TicTacToe environment for TorchRL.

    All game state lives in the tensordict passed through `_step`/`_reset`,
    so the env itself is stateless and `batch_locked` is False (it accepts
    inputs of any batch size).
    """
    batch_locked = False

    # The env logic is implemented as module-level functions; bind them
    # here so EnvBase finds them as (static) methods.
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed

    def __init__(self, td_params=None, device="cpu", batch_size=None):
        """Normalize `batch_size` to a torch.Size, then build the specs.

        Accepts None (defaults to a batch of 1), an int, or any iterable of
        ints (including torch.Size itself, which is a tuple subclass).
        """
        if batch_size is None:
            batch_size = torch.Size([1])
        elif isinstance(batch_size, int):
            batch_size = torch.Size([batch_size])
        elif isinstance(batch_size, Iterable):
            # torch.Size is a tuple, so this branch also covers it.
            batch_size = torch.Size(batch_size)
        else:
            # Raise instead of `assert` so validation survives `python -O`.
            raise TypeError("batch size must be torch.Size, list[int], or int")
        if td_params is None:
            td_params = self.gen_params(batch_size)
        super().__init__(device=device, batch_size=batch_size)
        self._make_spec(td_params)

    @staticmethod
    def player_perspective(state, player):
        """Return the board as seen by `player`: own stones read as +1."""
        return state * player
class RGBTransform(ObservationTransform):
    """Render the board under a 'pixels' key as a uint8 RGB image.

    Empty squares are grey, white stones white, black stones black. The
    image is channels-LAST: shape (*batch, 3, 3, 3) = (*batch, H, W, C),
    which is what matplotlib's imshow expects.
    """

    def __init__(self, out_key=None):
        out_keys = ['pixels'] if out_key is None else [out_key]
        super().__init__(in_keys=['state'], out_keys=out_keys)
        # uint8 RGB palette for the three possible square values.
        self.colors = {
            'white': torch.tensor([255, 255, 255], dtype=torch.uint8),
            'black': torch.tensor([0, 0, 0], dtype=torch.uint8),
            'grey': torch.tensor([128, 128, 128], dtype=torch.uint8)
        }

    def forward(self, tensordict):
        return self._call(tensordict)

    def _reset(self, tensordict, tensordict_reset):
        # Transforms must also decorate the tensordict produced by reset.
        return self._call(tensordict_reset)

    def _call(self, td):
        # Boolean masks over the (batch, 3, 3) board select pixel rows in
        # the (batch, 3, 3, 3) image; the 3-element colour broadcasts in.
        pixels = torch.zeros(*td.shape, 3, 3, 3, dtype=torch.uint8)
        pixels[td['state'] == 0] = self.colors['grey']
        pixels[td['state'] == 1] = self.colors['white']
        pixels[td['state'] == -1] = self.colors['black']
        td['pixels'] = pixels
        return td

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        N, H, W = observation_spec.shape
        # Channels-LAST (N, H, W, 3), matching the tensor built in _call.
        # (For the 3x3 board this Size is numerically identical to the
        # channel-first ordering, but the semantics matter for consumers.)
        return BoundedTensorSpec(
            minimum=0,
            maximum=255,
            shape=torch.Size((N, H, W, 3)),
            dtype=torch.uint8,
            device=observation_spec.device
        )
if __name__ == '__main__':
    # Interactive demo: render the board with matplotlib and let a human
    # play both sides by clicking squares.
    env = TicTacToe(batch_size=1)
    env = TransformedEnv(env)
    env.append_transform(RGBTransform())
    state = env.reset()
    import matplotlib.pyplot as plt
    # Set up the grid size
    rows, cols = 3, 3
    grid_size = 3  # Pixel size of the grid
    # Create a figure and axis
    fig, ax = plt.subplots()
    ax.set_xlim(0, grid_size)
    ax.set_ylim(0, grid_size)
    ax.set_aspect('equal')
    # origin='lower' puts board row 0 at the bottom of the window, matching
    # the y-coordinate arithmetic in on_click below.
    screen = ax.imshow(state['pixels'][0], extent=[0, 3, 0, 3], interpolation='nearest', origin='lower')
    # Draw the initial grid
    tile_size = grid_size / max(rows, cols)
    # Define the click event handler
    def on_click(event):
        # Translate a mouse click into a flat action index, step the env,
        # repaint, and report / restart when the game ends.
        global state
        if event.inaxes == ax:
            # Axes data coordinates run 0..3 in both directions; int()
            # truncates them to the clicked row/column.
            action = int(event.ydata) * 3 + int(event.xdata)
            if state['action_mask'][0, action]:
                state['action'] = torch.tensor([[action]])
                # step_mdp promotes the 'next' entries to the root tensordict;
                # keep the reward so the outcome can be reported below.
                state = step_mdp(env.step(state), exclude_reward=False)
                screen.set_data(state['pixels'][0])
                fig.canvas.draw()
                if state['terminated'][0]:
                    reward = state['reward'][0].item()
                    if reward == 0:
                        print('draw')
                    else:
                        print(f'{Player(int(reward)).name} wins')
                    state = env.reset()
            else:
                # Square already occupied (masked out).
                print("nope")
    # Connect the click event to the event handler
    fig.canvas.mpl_connect('button_press_event', on_click)
    plt.show()