When calculating a loss (with the mse_loss
function), should I call .detach()
on the target tensor if it has a grad_fn? And does it even matter?
It depends on your use case and whether you want to “train” the target tensor. Could you explain why the target has a valid grad_fn
, how it was created, and whether you want to calculate gradients for this operation?
I am making a DQN (Deep Q-Network), a reinforcement learning method in which the network approximates the reward for each action in a given state. When training, the target for a single action is the reward from the environment plus the predicted future reward ( target = reward + t.max(net.forward(next_state))
). This is why my target has a grad_fn: part of it is a prediction. So no, I don’t want the target to be trained, but I also don’t want to waste compute time calling .detach()
on everything, which is why I’m asking if it matters.
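As a minimal sketch of the setup described above (the toy network, shapes, and reward value are assumptions for illustration, not from the thread):

```python
import torch
import torch.nn as nn

# Hypothetical toy Q-network standing in for `net` from the discussion:
# 4 state features in, 2 action values out.
net = nn.Linear(4, 2)

state = torch.randn(1, 4)
next_state = torch.randn(1, 4)
reward = 1.0

# Target as described in the thread: reward + max predicted future Q-value.
# .detach() cuts the prediction out of the autograd graph, so the target
# has no grad_fn and no gradients flow back through this forward pass.
target = reward + net(next_state).max().detach()

prediction = net(state).max()
loss = nn.functional.mse_loss(prediction, target)
loss.backward()  # gradients flow only through `prediction`
```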
I would assume the gradient calculation would be more expensive than calling .detach()
on a tensor.
In any case, you could also directly wrap the target computation into a with torch.no_grad()
guard to avoid creating the computation graph in the first place.
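Continuing the sketch above (same assumed toy network), the no_grad variant looks like this:

```python
import torch
import torch.nn as nn

# Hypothetical toy Q-network, as before.
net = nn.Linear(4, 2)
next_state = torch.randn(1, 4)
reward = 1.0

# Inside torch.no_grad(), no computation graph is recorded at all,
# so the target never acquires a grad_fn and nothing needs detaching.
with torch.no_grad():
    target = reward + net(next_state).max()
```

This also skips saving intermediate activations for the target's forward pass, so it is at least as cheap as computing the graph and then detaching.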
Thanks, helps a lot!