How to ensure dimensions match if the states batch has a different dimension from the actions batch

I am trying to train a DQN to do optimal energy scheduling. Each state is a vector of 4 variables (floats) that is saved in the replay memory as a state tensor, and each action is an integer that is also saved in the memory as a tensor. I extract the batch of experiences as follows:

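```python
import torch

def extract_tensors(experiences):
    # Convert batch of Experiences to Experience of batches
    # (Experience is the namedtuple used by the replay memory, defined elsewhere)
    batch = Experience(*zip(*experiences))

    t1 = ...  # expression not shown in the original post
    t2 = ...  # expression not shown in the original post
    t3 = torch.stack(batch.reward)
    t4 = ...  # expression not shown in the original post

    return (t1, t2, t3, t4)
```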

I then unpack them to update the policy and target networks:

```python
experiences = memory.sample(batch_size)
states, actions, rewards, next_states = extract_tensors(experiences)
```
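To see how the batch dimensions can end up disagreeing, here is a minimal sketch comparing `torch.stack` and `torch.cat` on the kinds of tensors described above. It assumes each state was pushed as a 1-D float tensor of 4 values and each action as a 0-d integer tensor; the `Experience` field names are hypothetical and may differ from the ones in your code.

```python
from collections import namedtuple

import torch

# Hypothetical field layout; the post does not show how Experience is defined.
Experience = namedtuple("Experience", ("state", "action", "next_state", "reward"))

# Three fake transitions: 4-float states, integer actions stored as 0-d tensors.
experiences = [
    Experience(
        state=torch.tensor([4.0, 20.0, 4.0, 32.0]),
        action=torch.tensor(i),
        next_state=torch.tensor([5.0, 21.0, 3.0, 30.0]),
        reward=torch.tensor(1.0),
    )
    for i in range(3)
]

batch = Experience(*zip(*experiences))

# stack adds a new leading batch dimension to every element.
states_stacked = torch.stack(batch.state)    # shape [3, 4]
actions_stacked = torch.stack(batch.action)  # shape [3]

# cat joins along an existing dimension, so 1-D states collapse into one long vector.
states_cat = torch.cat(batch.state)          # shape [12]

# gather(dim=1, ...) needs a [batch_size, 1] index, so actions need one more dimension.
actions_index = actions_stacked.unsqueeze(1)  # shape [3, 1]

print(states_stacked.shape, actions_stacked.shape, states_cat.shape, actions_index.shape)
```

Whichever of `stack` or `cat` you use, the states batch should come out as `[batch_size, 4]` and the actions as `[batch_size, 1]`.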

My `QValues` class for the update looks like this:

```python
import torch

class QValues():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        # Q-value of the action actually taken; actions must be a [batch_size, 1] int64 tensor
        return policy_net(states).gather(dim=1, index=actions)

    @staticmethod
    def get_next(target_net, next_states):
        # The continuation of this line was cut off in the post; this version assumes
        # terminal next states are stored as all-zero tensors.
        final_state_locations = next_states.flatten(start_dim=1) \
            .eq(0).all(dim=1)
        non_final_state_locations = ~final_state_locations
        non_final_states = next_states[non_final_state_locations]
        batch_size = next_states.shape[0]
        values = torch.zeros(batch_size).to(QValues.device)
        values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()

        return values
```
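For 4-float states like the ones in this question, the terminal-state mask used by `get_next` can be illustrated with a small sketch. It assumes terminal next states are stored as all-zero tensors (a convention, not something the post confirms); the two-layer network and the number of actions are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_actions = 3  # hypothetical size of the action space
target_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, num_actions))

# Batch of 3 next states; the second one is all zeros, i.e. treated as terminal.
next_states = torch.tensor([
    [4.0, 20.0, 4.0, 32.0],
    [0.0, 0.0, 0.0, 0.0],
    [5.0, 18.0, 3.0, 30.0],
])

# A row is terminal when every entry is zero.
final_state_locations = (next_states.flatten(start_dim=1) == 0).all(dim=1)
non_final_state_locations = ~final_state_locations

values = torch.zeros(next_states.shape[0])
with torch.no_grad():
    # Only non-terminal rows get the target network's max Q-value; terminal rows stay 0.
    values[non_final_state_locations] = target_net(
        next_states[non_final_state_locations]
    ).max(dim=1)[0]
```

Note that `values` has shape `[batch_size]`, while `get_current` returns `[batch_size, 1]`; one of them usually needs an `unsqueeze(1)` or `squeeze(1)` before computing the loss.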

When I try running the training loop, I get the error below:

```
<ipython-input-8-4a79494b54ca> in <module>
--> 216                         current_q_values = QValues.get_current(policy_net, states, actions)
    217                         next_q_values = QValues.get_next(target_net, next_states)
    218                         target_q_values = (next_q_values * gamma) + rewards

<ipython-input-8-4a79494b54ca> in get_current(policy_net, states, actions)
    160     @staticmethod
    161     def get_current(policy_net, states, actions):
--> 162         return policy_net(states).gather(dim=1, index=actions)
    164     @staticmethod

RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at c:
```

I got somewhat the same error, although in my case I get a size mismatch when the network processes the batched states. Did you manage to find a workaround?

After doing a bit of research I found a quick workaround: add `actions = actions.squeeze()` in the `get_current` method. It worked for me. Hope it helps!
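The underlying rule is that `gather(dim=1, index=...)` needs an index tensor with the same number of dimensions as the Q-value tensor, i.e. `[batch_size, 1]` for a `[batch_size, num_actions]` output. A minimal sketch, with hypothetical sizes and random Q-values standing in for `policy_net(states)`, shows the failure and two ways to fix the index shape. The squeeze-then-unsqueeze pattern is the one used later in this thread; `reshape(-1, 1)` is an alternative I am swapping in because it also handles a batch of size 1.

```python
import torch

batch_size, num_actions = 3, 2  # hypothetical sizes
q_values = torch.randn(batch_size, num_actions)  # stands in for policy_net(states)

actions = torch.tensor([1, 0, 1])  # shape [3]: one dimension fewer than q_values

# Raises a RuntimeError: the index must have the same number of dims as the input.
try:
    q_values.gather(dim=1, index=actions)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Workaround from this thread: drop size-1 dims, then add a trailing one -> [3, 1].
index = actions.squeeze().unsqueeze(-1)
current_q = q_values.gather(dim=1, index=index)  # shape [3, 1]

# reshape(-1, 1) gives the same [3, 1] index and also works when batch_size == 1,
# where squeeze() would collapse the tensor to 0-d.
same_index = actions.reshape(-1, 1)
```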


Hey, thank you for the reply. I have tried that, but it still raises an error:

```
<ipython-input-4-fbc914f54837> in <module>
--> 217                         current_q_values = QValues.get_current(policy_net, states, actions)
    218                         next_q_values = QValues.get_next(target_net, next_states)
    219                         target_q_values = (next_q_values * gamma) + rewards

<ipython-input-4-fbc914f54837> in get_current(policy_net, states, actions)
    161     def get_current(policy_net, states, actions):
    162         actions = actions.squeeze()
--> 163         return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))
    165     @staticmethod

RuntimeError: Invalid index in gather at c:\a\w\1\s\tmp_conda_3.7_070024\conda\conda-bld\pytorch-cpu_1544079887239\work\aten\src\th\generic/THTensorEvenMoreMath.cpp:457
```
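Once the shapes line up, an "invalid index" error from `gather` most likely points at the index *values* rather than the shape: every action must be an int64 value in `[0, num_actions)`, where `num_actions` is the size of the network's output layer. A small check like the one below can be run on the sampled actions before calling `get_current`; the helper `check_action_indices` and the sizes are hypothetical.

```python
import torch


def check_action_indices(actions: torch.Tensor, num_actions: int) -> None:
    """Hypothetical helper: validate actions before using them as a gather index."""
    if actions.dtype != torch.long:
        raise TypeError(f"actions must be torch.long (int64), got {actions.dtype}")
    bad = (actions < 0) | (actions >= num_actions)
    if bad.any():
        raise ValueError(
            f"actions out of range [0, {num_actions}): {actions[bad].tolist()}"
        )


num_actions = 2  # hypothetical size of the network's output layer
q_values = torch.randn(3, num_actions)

good_actions = torch.tensor([[0], [1], [1]])
check_action_indices(good_actions, num_actions)
good_q = q_values.gather(dim=1, index=good_actions)  # works: shape [3, 1]

bad_actions = torch.tensor([[0], [2], [1]])  # 2 is not a valid column index
try:
    check_action_indices(bad_actions, num_actions)
    caught = False
except ValueError:
    caught = True

# gather itself also refuses the out-of-range index.
try:
    q_values.gather(dim=1, index=bad_actions)
    gather_failed = False
except RuntimeError:
    gather_failed = True
```

If the check fails, compare the environment's action range with the output size of the final layer of `policy_net`.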

What is the shape of your states? Assuming you are using images as input, the batch of states should have shape torch.Size([batch_size, number of colour channels, height, width]), and each individual state should have shape torch.Size([1, number of colour channels, height, width]). The leading dimension of each state should be 1 so that the states concatenate into a batch that matches the network's input and the `get_current` method.
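For the image case described above, a minimal sketch (with a hypothetical image size) shows why the leading dimension of 1 matters: `torch.cat` along dim 0 turns `[1, C, H, W]` states into a `[batch_size, C, H, W]` batch, whereas `torch.stack` on the same states adds an extra dimension the network does not expect.

```python
import torch

channels, height, width = 3, 40, 60  # hypothetical image size

# Each state carries a leading dimension of 1: [1, C, H, W].
states = [torch.rand(1, channels, height, width) for _ in range(5)]

# Concatenating along dim 0 yields the [batch_size, C, H, W] layout conv layers expect.
state_batch = torch.cat(states)

# Stacking the same states gives [5, 1, C, H, W]: one dimension too many.
wrong_batch = torch.stack(states)
```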

Yes. You could use `state = state.squeeze().unsqueeze(dim=0)` to reshape each state so that it agrees with the `get_current` method. Do this before pushing the state to memory.
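As a sketch of what `squeeze().unsqueeze(dim=0)` does to states that arrive in slightly different shapes (the example values are made up): every size-1 dimension is dropped, then a single leading dimension is added back, so each state becomes `[1, 4]` and the batch concatenates to `[batch_size, 4]`.

```python
import torch

# States may arrive in slightly different shapes depending on how they were built.
raw_states = [
    torch.tensor([4.0, 20.0, 4.0, 32.0]),            # [4]
    torch.tensor([[5.0, 21.0, 3.0, 30.0]]),          # [1, 4]
    torch.tensor([[[6.0], [22.0], [2.0], [28.0]]]),  # [1, 4, 1]
]

# squeeze() drops every size-1 dim; unsqueeze(dim=0) adds back one batch dim.
normalized = [s.squeeze().unsqueeze(dim=0) for s in raw_states]

state_batch = torch.cat(normalized)  # [3, 4]
```

One caveat: `squeeze()` also removes any genuine feature dimension of size 1, so this pattern is only safe when no real dimension of the state can have size 1.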

My states are just tensors of 4 floats. My environment returns each state as a list, say [4, 20, 4, 32], so I convert it to a tensor before passing it through the network.
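With 4-float states there are no channel, height, or width dimensions: each state can be stored as a `[1, 4]` float tensor, which concatenates into a `[batch_size, 4]` batch for a network whose first layer takes 4 inputs. Below is a minimal end-to-end sketch; the network architecture, the number of actions, and the helpers `to_state_tensor` and `to_action_tensor` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_actions = 3  # hypothetical number of discrete scheduling actions
policy_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, num_actions))


def to_state_tensor(obs):
    """Hypothetical helper: turn an environment list of 4 numbers into a [1, 4] float tensor."""
    return torch.tensor(obs, dtype=torch.float32).unsqueeze(0)


def to_action_tensor(action):
    """Hypothetical helper: store an integer action as a [1, 1] int64 tensor."""
    return torch.tensor([[action]], dtype=torch.long)


# Pretend these transitions came out of the replay memory.
observations = [[4, 20, 4, 32], [5, 18, 3, 30], [6, 22, 2, 28]]
chosen_actions = [0, 2, 1]

states = torch.cat([to_state_tensor(o) for o in observations])      # [3, 4]
actions = torch.cat([to_action_tensor(a) for a in chosen_actions])  # [3, 1]

current_q_values = policy_net(states).gather(dim=1, index=actions)  # [3, 1]
```

Storing each action as a `[1, 1]` tensor means `torch.cat` directly produces the `[batch_size, 1]` index that `gather` needs, so no squeeze/unsqueeze is required inside `get_current`.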