I am trying to train a DQN to do optimal energy scheduling. Each state comes as a vector of 4 variables (represented by floats) saved in the replay memory as a state tensor, each action is an integer saved in the memory as a tensor too. I extract the batch of experiences as:
def extract_tensors(experiences):
# Convert batch of Experiences to Experience of batches
batch = Experience(*zip(*experiences))
t1 = torch.cat(batch.state)
t2 = torch.cat(batch.action)
t3 = torch.stack(batch.reward)
t4 = torch.cat(batch.next_state)
return (t1,t2,t3,t4)
I then unpacked them for purposes of updating the prediction and the target networks as:
experiences = memory.sample(batch_size)
states, actions, rewards, next_states = extract_tensors(experiences)
My Qvalues class for update looks like this:
class QValues():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@staticmethod
def get_current(policy_net, states, actions):
return policy_net(states).gather(dim=1, index=actions)
@staticmethod
def get_next(target_net, next_states):
final_state_locations = next_states.flatten(start_dim=1) \
.max(dim=1)[0].eq(0).type(torch.bool)
non_final_state_locations = (final_state_locations == False)
non_final_states = next_states[non_final_state_locations]
batch_size = next_states.shape[0]
values = torch.zeros(batch_size).to(QValues.device)
values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()
return values
When I try running the training loop, I get the error below:
<ipython-input-8-4a79494b54ca> in <module>
214
215
--> 216 current_q_values = QValues.get_current(policy_net, states, actions)
217 next_q_values = QValues.get_next(target_net, next_states)
218 target_q_values = (next_q_values * gamma) + rewards
<ipython-input-8-4a79494b54ca> in get_current(policy_net, states, actions)
160 @staticmethod
161 def get_current(policy_net, states, actions):
--> 162 return policy_net(states).gather(dim=1, index=actions)
163
164 @staticmethod
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at c: