Hello,
I’ve been struggling with the same issue for the last couple of days.
Here’s my neural network:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Neural network from the Deep CFR paper
class CardEmbedding(nn.Module):
    def __init__(self, dim):
        super(CardEmbedding, self).__init__()
        self.rank = nn.Embedding(13, dim)
        self.suit = nn.Embedding(4, dim)
        self.card = nn.Embedding(52, dim)

    def forward(self, input):
        B, num_cards = input.shape
        x = input.view(-1)
        valid = x.ge(0).float()  # -1 means 'no card'
        x = x.clamp(min=0)
        embs = self.card(x) + self.rank(x // 4) + self.suit(x % 4)
        embs = embs * valid.unsqueeze(1)  # zero out 'no card' embeddings
        return embs.view(B, num_cards, -1).sum(1)
class DeepCFRModel(nn.Module):
    def __init__(self, n_card_types=7, n_bets=20, max_bet=100000, n_actions=4, dim=256):
        super(DeepCFRModel, self).__init__()
        self.max_bet = max_bet
        self.card_embeddings = nn.ModuleList(
            [CardEmbedding(dim) for _ in range(n_card_types)]
        )
        self.card1 = nn.Linear(dim * n_card_types, dim)
        self.card2 = nn.Linear(dim, dim)
        self.card3 = nn.Linear(dim, dim)
        self.card4 = nn.Linear(dim, dim)
        self.card5 = nn.Linear(dim, dim)
        self.bet1 = nn.Linear(2 * n_bets, dim)
        self.bet2 = nn.Linear(dim, dim)
        self.bet3 = nn.Linear(dim, dim)
        self.comb1 = nn.Linear(2 * dim, dim)
        self.comb2 = nn.Linear(dim, dim)
        self.comb3 = nn.Linear(dim, dim)
        self.action_head = nn.Linear(dim, n_actions)
        self.init_weights()
    def forward(self, cards, bets):
        # 1. card branch
        # embed hole, flop, and optionally turn and river
        card_embs = []
        for embedding, card_group in zip(self.card_embeddings, cards):
            if card_group.numel():
                card_embs.append(embedding(card_group.view(-1, 1)))
        card_embs = torch.cat(card_embs, dim=1)
        x = F.relu(self.card1(card_embs))
        x1 = F.relu(self.card2(x))
        x2 = F.relu(self.card3(x1))
        x3 = F.relu(self.card4(x2))
        x4 = F.relu(self.card5(x3))

        # 2. bet branch
        bet_size = bets.clamp(0, self.max_bet)
        bet_occurred = bets.ge(0)
        bet_feats = torch.cat([bet_size, bet_occurred.float()], dim=1)
        y = F.relu(self.bet1(bet_feats))
        y1 = F.relu(self.bet2(y) + y)
        y2 = F.relu(self.bet3(y1) + y1)

        # 3. combined trunk
        z = torch.cat([x4, y2], dim=1)
        z1 = F.relu(self.comb1(z))
        z2 = F.relu(self.comb2(z1) + z1)
        z3 = F.relu(self.comb3(z2) + z2)
        z4 = F.normalize(z3, p=2, dim=1)  # L2 normalization of each row
        return self.action_head(z4)
    def init_weights(self):
        for param in self.parameters():
            nn.init.zeros_(param)

    def init_weights_randomly(self):
        for param in self.parameters():
            nn.init.normal_(param, mean=0, std=0.1)
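The forward pass itself seems fine; a dummy-input smoke test like the following runs cleanly for me (the shapes are placeholders I chose: seven singleton card groups with -1 as 'no card', and a 20-slot bet history):

# smoke test with placeholder inputs
model = DeepCFRModel(n_card_types=7, n_bets=20, n_actions=4, dim=256)
model.init_weights_randomly()

B = 8
cards = [torch.randint(-1, 52, (B, 1)) for _ in range(7)]  # -1 = 'no card'
bets = torch.rand(B, 20) * 100

out = model(cards, bets)
print(out.shape)  # torch.Size([8, 4])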
Then I create three instances of this neural network:
class DeepCFR:
    def __init__(
        self,
        starting_stacks: tuple[int],
        blinds: tuple[int, int],
        ranges: tuple[tuple[tuple[str, str]], tuple[tuple[str, str]]],
        pot: int,
        board: tuple[str],
        bets: tuple[tuple[float, float], tuple[float, float]] = ((0.3, 0.7), (0.3, 0.7)),
        buyin: int = 100,
        iterations: int = 1,
        K: int = 30,
        n_players: int = 2,
        mv_max: int = 100000,
        mpi_max: int = 200000,
    ):
        self.iterations = iterations
        self.n_players = n_players
        self.K = K
        self.starting_stacks = starting_stacks
        self.blinds = blinds
        self.ranges = ranges
        self.pot = pot
        self.board = board
        self.bets = bets
        self.buyin = buyin
        self.value_raise = 0.1
        self.max_size = 20
        self.m_v = [[], []]
        self.m_pi = []
        self.mv_max = mv_max
        self.mpi_max = mpi_max
        self.val_nets = [DeepCFRModel(), DeepCFRModel()]
        self.val_net_optims = [
            optim.Adam(self.val_nets[0].parameters(), lr=LR),
            optim.Adam(self.val_nets[1].parameters(), lr=LR),
        ]
        self.strategynet = DeepCFRModel()
        self.strategynet_optim = optim.Adam(self.strategynet.parameters(), lr=LR)
        self.n_epochs = 1
        self.batch_size = 512
    def train_advantage_network(self, player):
        print(f"Training advantage network of player {player}...")
        criterion = torch.nn.MSELoss()
        running_loss = 0.0
        batch_loss = []
        self.val_nets[player].init_weights_randomly()
        for _ in range(self.n_epochs):
            for batch in range(len(self.m_v[player])):
                self.val_net_optims[player].zero_grad()
                batch_data = self.m_v[player][batch]
                cards = batch_data["cards"]
                bet_history = batch_data["bet_history"]
                regrets = batch_data["regrets"]
                if bet_history.size(1) > 20:
                    print("bet_history", bet_history.size())
                outputs = self.val_nets[player](cards, bet_history)
                loss = criterion(outputs, regrets)
                batch_loss.append(loss.float())
                if batch % self.batch_size == 0:
                    batch_loss_mean = torch.stack(batch_loss).mean()
                    print(f"Batch loss: {batch_loss_mean}")
                    batch_loss_mean.backward()
                    self.val_net_optims[player].step()
                    batch_loss = []
                running_loss += loss.item()
            print(f"Loss: {running_loss/len(self.m_v[player])}")
        self.trained = True
    def train_strategy_network(self):
        criterion = torch.nn.MSELoss()
        running_loss = 0.0
        batch_loss = []
        self.strategynet.init_weights_randomly()
        for _ in range(self.n_epochs):
            for batch in range(len(self.m_pi)):
                self.strategynet_optim.zero_grad()
                batch_data = self.m_pi[batch]
                cards = batch_data["cards"]
                bet_history = batch_data["bet_history"]
                pred = batch_data["pred"]
                outputs = self.strategynet(cards, bet_history)
                loss = criterion(outputs, pred)
                batch_loss.append(loss.float())
                if batch % self.batch_size == 0:
                    batch_loss_mean = torch.stack(batch_loss).mean()
                    print(f"Batch loss: {batch_loss_mean}")
                    batch_loss_mean.backward()
                    self.strategynet_optim.step()
                    batch_loss = []
                running_loss += loss.item()
            print(f"Loss: {running_loss/len(self.m_pi)}")
Training the first two neural networks (the two in self.val_nets = [DeepCFRModel(), DeepCFRModel()]) works perfectly well. However, I cannot train self.strategynet; I always get the same error:
File "C:\Users\lucas\Desktop\Poker\Poker-Texas-Holdem-Solver\python-src\main.py", line 29, in <module>
main()
File "C:\Users\lucas\Desktop\Poker\Poker-Texas-Holdem-Solver\python-src\main.py", line 26, in main
DeepCFRAlgo.run()
File "C:\Users\lucas\Desktop\Poker\Poker-Texas-Holdem-Solver\python-src\DeepCFR\deepcfr.py", line 201, in run
self.train_strategy_network()
File "C:\Users\lucas\Desktop\Poker\Poker-Texas-Holdem-Solver\python-src\DeepCFR\deepcfr.py", line 573, in train_strategy_network
batch_loss_mean.backward()
File "C:\Users\lucas\Desktop\Poker\Poker-Texas-Holdem-Solver\PokerSolverEnv\Lib\site-packages\torch\_tensor.py", line 522, in backward
torch.autograd.backward(
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 4]], which is output 0 of AsStridedBackward0, is at version 4; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
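From reading about this error class, my understanding is that it means a tensor saved for backward was modified in place before backward() ran, and an optimizer step is itself an in-place modification. A minimal sketch of that failure mode (a toy example, not my actual code):

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
opt = optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(16, 4)

stale = net(x).sum()     # this loss's graph saved references to the current weights
net(x).sum().backward()  # an independent second loss: its backward is fine
opt.step()               # in-place update bumps the weights' version counters
stale.backward()         # RuntimeError: ... modified by an inplace operation

What makes me suspect something like this: the [torch.FloatTensor [256, 4]] in my error matches the transpose of an action_head.weight (dim=256, n_actions=4), as if a graph built through one of the networks survives an optimizer step. But I don't see where that happens in my training loops.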
I did:
torch.autograd.set_detect_anomaly(True)
But it does not give me any more information.
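In case placement matters: my understanding is that anomaly detection only annotates operations recorded while it is active, so it has to be enabled before the forward pass that builds the failing graph. The context-manager form makes that explicit:

with torch.autograd.detect_anomaly():
    outputs = self.strategynet(cards, bet_history)
    loss = criterion(outputs, pred)
    loss.backward()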
I tried to replace as many in-place operations as possible in my forward method.
I originally had operations like:
y = F.relu(self.bet1(bet_feats))
y = F.relu(self.bet2(y) + y)
y = F.relu(self.bet3(y) + y)
which I changed to:
y = F.relu(self.bet1(bet_feats))
y1 = F.relu(self.bet2(y) + y)
y2 = F.relu(self.bet3(y1) + y1)
But something else must be wrong, and I can’t figure out what.
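One probe I’m considering to narrow down which in-place update is responsible: watching the version counter that autograd compares (the private _version attribute), e.g.

w = self.strategynet.action_head.weight
print(w._version)  # bumps on every in-place modification, e.g. optimizer.step()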
Also, I have already uninstalled and reinstalled PyTorch. I’m using PyTorch 2.2.2 with CUDA.