From my understanding of backpropagation, during the forward computation each tensor records the operation that produced it (its grad_fn), so that when we call Loss.backward() PyTorch can compute the partial derivatives of the loss with respect to the network parameters.
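For example (a toy sketch just to illustrate what I mean, not my actual code):

```python
import torch

w = torch.tensor([2.0, 3.0], requires_grad=True)  # stands in for a network parameter

y = (w * w).sum()   # y records how it was produced
print(y.grad_fn)    # e.g. <SumBackward0 ...>

y.backward()        # walks the recorded graph
print(w.grad)       # dy/dw = 2*w -> tensor([4., 6.])
```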
Below is my code for sampling a mini-batch of experience and updating a DQN with a temporal-difference target from the Bellman equation. For debugging purposes I printed out the tensors. The numbers are all correct, but the tensors Qpred (Q value of the action selected in each state) and Qoptimal (the target computed from the Bellman equation) show grad_fn = CopySlices when printed.
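Concretely, the per-sample target I am trying to build is the standard one-step TD target, written out here as pseudocode so the loop below is easier to follow (γ = 0.99, and the max is taken only over actions the mask marks as valid):

```
Qoptimal[i] = rewards[i] + 0.99 * max over valid a' of Qvaluesnext[i][a']   # if not dones[i]
Qoptimal[i] = rewards[i]                                                    # if dones[i]
```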
My concern is that grad_fn = CopySlices means I made a copy of the tensor and the gradient information it was tracking is lost, which would cause issues when calling Loss.backward(). (A minimal repro of the pattern I'm worried about is at the bottom of the post.)
```python
def learn(self):
    if self.counter < self.batchsize:
        return
    # sample a minibatch of transitions
    states, actions, rewards, nextstates, dones, actionmasks = self.sample_experience()

    # Qnetwork: Q(s, a) for every action in each sampled state
    Qvalues = self.Qnetwork.forward(torch.FloatTensor(states))
    Qpred = torch.FloatTensor(np.zeros(self.batchsize))

    # Qtarget: Q(s', a') for every action in each next state
    Qvaluesnext = self.Qtarget.forward(torch.FloatTensor(nextstates))
    Qoptimal = torch.FloatTensor(np.zeros(self.batchsize))

    # debug
    print('actionmask')
    print(actionmasks)
    print('Qvalues')
    print(Qvalues)
    print('Qvaluesnext')
    print(Qvaluesnext)

    for i in range(self.batchsize):
        # Qpred: Q value of the action actually taken in state i
        Qpred[i] = Qvalues[i][actions[i]]

        # Qoptimal: one-step TD target from the Bellman equation
        if not dones[i]:
            # mask invalid actions in Qvaluesnext
            Qindex = -1  # index of the largest valid Q in Qvaluesnext[i]
            for j in range(self.actionsize):
                if actionmasks[i][j] == 0:  # 0 means action j is valid
                    if Qindex == -1:  # first valid action
                        Qindex = j
                    elif Qvaluesnext[i][j] > Qvaluesnext[i][Qindex]:  # found better valid action
                        Qindex = j
            # debug
            if Qindex == -1:  # error case: no valid action found
                print('Error -> Qindex = -1')
                return
            # Qoptimal[i] = torch.max(Qvaluesnext[i])
            Qoptimal[i] = rewards[i] + 0.99 * Qvaluesnext[i][Qindex]
        else:
            Qoptimal[i] = rewards[i]

    # debug
    print('Qpred')
    print(Qpred)
    print('Qoptimal')
    print(Qoptimal)

    # update Qnetwork parameters
    self.optim.zero_grad()
    Loss = self.LossFn(Qpred, Qoptimal)
    # debug
    print(Loss)
    Loss.backward()
    self.optim.step()
```
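And here is the minimal repro of the pattern I mentioned above (toy tensors and made-up names just for illustration, not my actual network):

```python
import torch

# stand-ins for the network output and the sampled actions
qvalues = torch.randn(4, 3, requires_grad=True)  # pretend this came out of Qnetwork
actions = [0, 2, 1, 0]

# preallocate, then fill one element at a time, like Qpred/Qoptimal in learn()
qpred = torch.zeros(4)
for i in range(4):
    qpred[i] = qvalues[i][actions[i]]

print(qpred)  # prints with grad_fn=<CopySlices>, same as in my debug output
              # -> is calling backward() through a tensor built this way still correct?
```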