Hi @SimonW, thanks for your help! I’ve just updated the loss function:
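```python
import torch

# Element-wise squared error; reduction='none' is the current spelling of the
# deprecated size_average=False, reduce=False arguments
loss_func = torch.nn.MSELoss(reduction='none')
```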
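For reference, a small sketch of what `reduction='none'` (the current name for `reduce=False`) returns: one squared error per element, with the same shape as the inputs, so the values can be clipped before summing them by hand. The tensors below are made-up examples, not values from the actual training code.

```python
import torch

loss_func = torch.nn.MSELoss(reduction='none')

q_phi = torch.tensor([0.5, 2.0, -3.0])  # made-up predictions
y = torch.tensor([0.0, 0.0, 0.0])       # made-up targets

per_element = loss_func(q_phi, y)       # (q_phi - y)**2 for each element, same shape as q_phi
print(per_element)

# Summing the unreduced loss gives the same value as reduction='sum'
total = torch.nn.MSELoss(reduction='sum')(q_phi, y)
print(torch.allclose(per_element.sum(), total))
```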
I also coded the backward pass accordingly:
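```python
# Compute the loss and run the backward pass
error = loss_func(q_phi, y)        # element-wise squared error (q_phi - y)**2, already >= 0
error = torch.clamp(error, max=1)  # cap each squared error at 1 (same as clamping q_phi - y to [-1, 1] before squaring)
error = error.sum()
error.backward()
```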
So far it runs without any errors, so at least the ‘backward’ operation goes through!
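Running without errors only shows that autograd can differentiate the loss, not that the clipping behaves as intended. As a quick sanity check, the sketch below uses a few made-up TD errors and compares the gradients of one clipping variant (each squared error capped at 1) with those of `torch.nn.functional.smooth_l1_loss`, the Huber loss that DQN-style implementations commonly use for error clipping. The capped version's gradient drops to zero once |q_phi - y| > 1, while the Huber gradient stays at ±1 there (it equals `clamp(q_phi - y, -1, 1)`).

```python
import torch
import torch.nn.functional as F

# Made-up TD errors: two inside [-1, 1], two outside
q_phi = torch.tensor([0.5, -0.25, 3.0, -2.0], requires_grad=True)
y = torch.zeros(4)

# Variant 1: cap each squared error at 1
capped = torch.clamp(F.mse_loss(q_phi, y, reduction='none'), max=1).sum()
capped.backward()
grad_capped = q_phi.grad.clone()  # 2 * (q_phi - y) inside the cap, 0 outside

# Variant 2: Huber loss (smooth L1 with beta=1)
q_phi.grad = None  # reset accumulated gradients before the second backward
huber = F.smooth_l1_loss(q_phi, y, reduction='sum', beta=1.0)
huber.backward()
grad_huber = q_phi.grad.clone()   # clamp(q_phi - y, -1, 1)

print(grad_capped)
print(grad_huber)
```

With the capped version, samples whose error exceeds 1 contribute no gradient at all; if the goal is the DQN-style behaviour of bounded but non-zero updates, `F.smooth_l1_loss(q_phi, y, reduction='sum')` is one possible drop-in option.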
I’ll test it in the Atari environment and let you know how it goes. The code is here in case anyone wants to check it out in the meantime.