Issue with handling invalid moves in reinforcement learning

I am new to PyTorch, and so far I love it compared to TensorFlow.

I ran into an issue during a reinforcement learning exercise.
What I am trying to do is compute the softmax over the valid actions only.
For example, let’s say the model gives me these logits for 3 actions: (0.3, 0.4, 0.1),
but I don’t want to use the first one, so I set it to 0: (0, 0.4, 0.1) and take the softmax of only (0.4, 0.1):
F.softmax(torch.tensor([0.4, 0.1]), dim=-1) = tensor([0.5744, 0.4256])
So eventually I would like to get [0, 0.5744, 0.4256] and feed it to the system as below:

I tried many versions of filterValidActions(), but every one of them throws this error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I know that I am doing an in-place operation on the action probabilities, but I could not figure out how to take the softmax over the valid actions and still keep the gradient.
Every version needs an assignment like
probs[:,:,2] = 0 when the action is 0, and this is what causes the issue.
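This error is raised whenever a tensor that autograd saved for the backward pass is modified in place before backward() runs. A minimal toy reproduction, using made-up tensors rather than the code from this post: exp() saves its output for backward, so overwriting an entry of that output breaks the gradient, while building a new tensor with torch.where (one possible out-of-place alternative) does not.

```python
import torch

x = torch.tensor([0.3, 0.4, 0.1], requires_grad=True)

# Failing pattern: exp() saves its output y for backward, then y is overwritten.
y = x.exp()
y[2] = 0  # in-place write: y was saved for backward, so this invalidates it
error = None
try:
    y.sum().backward()
except RuntimeError as e:
    error = str(e)  # "...modified by an inplace operation..."
print(error is not None)  # True

# Working pattern: build a new tensor instead of writing into y.
x.grad = None  # start from a clean gradient
y = x.exp()
keep = torch.tensor([True, True, False])
z = torch.where(keep, y, torch.zeros_like(y))  # y itself is never modified
z.sum().backward()
print(x.grad)  # gradient is 0 for the zeroed entry
```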

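import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# model, device and SavedAction are defined elsewhere in the training script
def select_my_action(state, model_hidden, prev_action):
    state = torch.FloatTensor(state)
    state = state.view(1, 1, state.size(0)).to(device)
    probs, state_value, model_hidden = model(state, model_hidden)

    # restrict the distribution to the actions allowed after prev_action
    probs = filterValidActions(probs, prev_action)

    m = Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.item(), model_hidden

def filterValidActions(probs, action):
    newProbs = probs.clone()
    if action == 0:  # previous action 0 -> valid actions [0, 1]
        new2Probs = F.softmax(newProbs[:,:,0:2], dim=-1)
        newProbs[:,:,0:2] = new2Probs  # in-place write into the clone
        newProbs[:,:,2] = 0            # in-place write into the clone
    return newProbs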
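A minimal sketch of an out-of-place alternative, assuming the model output is already a softmax over all actions (as the name probs suggests). Multiplying by a boolean mask and renormalizing gives exactly the softmax over the valid logits, because the shared normalizer cancels. The function name filter_valid_actions and the valid_mask argument are hypothetical, and the toy logits are the numbers from the question.

```python
import torch
from torch.distributions import Categorical

def filter_valid_actions(probs, valid_mask):
    """Zero out invalid actions and renormalize without any in-place writes.

    probs:      (..., n_actions) probabilities, e.g. the model's softmax output
    valid_mask: bool tensor broadcastable to probs, True = action allowed
    Assumes at least one action is valid (otherwise the division gives NaN).
    """
    mask = valid_mask.to(device=probs.device, dtype=probs.dtype)
    masked = probs * mask                              # new tensor; probs is untouched
    return masked / masked.sum(dim=-1, keepdim=True)   # renormalize over valid actions

# Toy stand-in for the model output, shape (1, 1, 3) like in select_my_action
logits = torch.tensor([[[0.3, 0.4, 0.1]]], requires_grad=True)
probs = torch.softmax(logits, dim=-1)

# The example from the question: the first action is not allowed
valid = torch.tensor([False, True, True])
new_probs = filter_valid_actions(probs, valid)
print(new_probs)  # values ≈ [0, 0.5744, 0.4256]

m = Categorical(new_probs)
action = m.sample()                  # never samples the masked action
m.log_prob(action).sum().backward()  # runs without the in-place error
print(action.item(), logits.grad)
```

For the case in filterValidActions (previous action 0), the mask would be [True, True, False] instead.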

I looked at this link, but I still could not figure out my issue.

Do you have any solution to this issue?
Thanks a lot

You can replace the three lines inside the if block with the following and see whether it works:

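only2probs = F.softmax(newProbs[:,:,0:2], dim=-1)
# zeros_like matches the device and dtype of only2probs (a plain torch.zeros(...) defaults to the CPU)
zeros = torch.zeros_like(only2probs[:,:,:1])
# torch.cat returns a new tensor instead of writing into newProbs in place
newProbs = torch.cat((only2probs, zeros), dim=-1)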

Thank you so much, Arul. It worked!
It keeps the gradient as well.
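One way to confirm that gradients flow through the concatenated tensor is a small self-contained check. keep_first_two is a hypothetical helper wrapping the answer's three lines, and the toy (1, 1, 3) input stands in for the model output. The entry replaced by zeros gets a gradient of exactly 0, and the two kept entries get the usual softmax gradient.

```python
import torch
import torch.nn.functional as F

def keep_first_two(probs):
    # Hypothetical helper wrapping the three lines from the answer
    only2probs = F.softmax(probs[:, :, 0:2], dim=-1)
    zeros = torch.zeros_like(only2probs[:, :, :1])   # same device/dtype as only2probs
    return torch.cat((only2probs, zeros), dim=-1)    # new tensor, no in-place writes

probs = torch.tensor([[[0.3, 0.4, 0.1]]], requires_grad=True)
new_probs = keep_first_two(probs)

# Backpropagate from the first probability only
new_probs[0, 0, 0].backward()
print(new_probs)
print(probs.grad)  # ≈ [[[0.2494, -0.2494, 0.0]]]; the third entry no longer matters
```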