Hello there,
I’m trying to implement (copy-pasting from different sources) the DDQN algorithm in PyTorch. I get the following error when calling loss.backward():

RuntimeError: Function MmBackward returned an invalid gradient at index 1 - got [128, 16] but expected shape compatible with [4, 16]
I believe this is the relevant code:
transitions = self.memory.sample()
batch_state, batch_next_state, batch_action, batch_reward = zip(*transitions)
batch_state = torch.cat(batch_state)
batch_next_state = torch.cat(batch_next_state)
batch_action = torch.cat(batch_action).squeeze(1)
batch_reward = torch.cat(batch_reward)
# current Q values are estimated by NN for all actions
current_q_values = self.policy_net(batch_state).gather(0, batch_action)
# expected Q values are estimated from actions which gives maximum Q value
max_next_q_values, _ = self.target_net(batch_next_state).detach().max(0)
expected_q_values = batch_reward + (self.gamma * max_next_q_values)
# Compute loss
loss = self.criterion(current_q_values, expected_q_values)
# Zero the gradients and backpropagate
self.optimizer.zero_grad()
loss.backward()
for param in self.policy_net.parameters():
    param.grad.data.clamp_(-1, 1)
# Optimizer
self.optimizer.step()
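As a side note on the .max(...) call above: torch.max along a dimension returns a (values, indices) tuple, which is why max_next_q_values is unpacked with `values, _ = ...`. A tiny standalone illustration of that API behavior:

```python
import torch

# torch.max along a dimension returns a (values, indices) tuple,
# which is why max_next_q_values is unpacked from the result above.
t = torch.tensor([[1.0, 5.0],
                  [3.0, 2.0]])
values, indices = t.max(0)  # maximum over dim 0, i.e. down each column
print(values)   # tensor([3., 5.])
print(indices)  # tensor([1, 0])
```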
Both policy_net and target_net are just two linear layers with a ReLU activation.
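In case it helps, the networks look roughly like this. The sizes 4 (state dim) and 16 (hidden dim) are my guesses based on the shapes in the error message, and n_actions is a placeholder:

```python
import torch
import torch.nn as nn

# Rough sketch of policy_net / target_net as described above.
# The sizes 4 and 16 are assumed from the shapes in the error message;
# n_actions is a placeholder, not the real value from my code.
class QNet(nn.Module):
    def __init__(self, n_states=4, n_hidden=16, n_actions=2):
        super().__init__()
        self.fc1 = nn.Linear(n_states, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_actions)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))
```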
The error makes me believe that the graph is messed up somewhere, but I’m not sure how such a situation can be reached. I assume MmBackward refers to a torch.mm computation, but I don’t think I ever call torch.mm explicitly.
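To sanity-check that assumption, a bare matrix multiply (arbitrary shapes, just for illustration) does show up in autograd under that name:

```python
import torch

# A plain matrix multiply is recorded as MmBackward in the autograd
# graph, so the name can appear even when torch.mm is never called
# explicitly (nn.Linear uses matrix multiplies internally).
x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(4, 16, requires_grad=True)
y = x @ w
print(type(y.grad_fn).__name__)  # 'MmBackward' (or 'MmBackward0' on newer PyTorch)
```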
I am incredibly lost and any help would be much appreciated. Thanks in advance!