"Trying to backward through the graph a second time" error. Could someone help?


I’m using PyTorch for reinforcement learning and I have a question about using torch.split. After each episode, I have a batch of state/action/reward tuples. For the critic update, I want to split this dataset into minibatches. Here’s what I’m doing:

        # VALUE UPDATE
        batch_size = 32
        # ceiling division: no empty last batch when len(self.v_mem) is a multiple of batch_size
        nb_batchs = (len(self.v_mem) + batch_size - 1) // batch_size

        preds =  # self.v_mem holds all the state-value predictions
        chunks = torch.split(preds, batch_size)

        for b in range(nb_batchs):
            current_min = b * batch_size
            current_max = min((b + 1) * batch_size, len(self.v_mem))

            # self.v_n_mem is a list of next-state value predictions; each entry is a tensor
            # without history (created with .detach())
            selected = self.v_n_mem[current_min:current_max]

            batch_next_estim_tensor = torch.stack(selected).float()
            batch_target = retours[current_min:current_max, 0].reshape(-1, 1) + 0.99 * batch_next_estim_tensor.reshape(-1, 1)

            batch_pred = chunks[b]
            value_loss = F.mse_loss(batch_pred, batch_target.detach())


The first iteration of this loop runs fine, but on the next one I get an error saying that I’m trying to backward through the graph a second time. I don’t see why. Could someone help?

Thanks!

What is this line doing?

Oops, I forgot to include this. It just calls the optimizer:

def maj(self, loss):

Thanks for pointing that out

Since it contains previous predictions, you need to pass retain_graph=True to your backward call. Otherwise, the autograd graph is freed during the first call.
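A small self-contained check of the retain_graph=True suggestion, assuming preds is built with torch.stack and using a hypothetical two-layer critic, SGD optimizer and random states (none of these names come from the original code). When nothing touches the weights between the two backward calls, retain_graph=True lets both chunks backpropagate. If the optimizer steps in between, as it would when maj is called once per minibatch, the second call fails in recent PyTorch versions with an in-place modification error instead, because the retained graph saved weights that optimizer.step() has since updated in place.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run(step_between_chunks: bool) -> list:
    """Backward once per chunk with retain_graph=True, optionally stepping the optimizer in between."""
    torch.manual_seed(0)
    # Hypothetical critic, optimizer and states; only the autograd behaviour matters here.
    value = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    optimizer = torch.optim.SGD(value.parameters(), lr=0.01)
    states = torch.randn(64, 4)
    preds = torch.stack([value(s) for s in states])  # (64, 1), one graph per prediction

    outcomes = []
    for chunk in torch.split(preds, 32):
        optimizer.zero_grad()
        loss = F.mse_loss(chunk, torch.zeros_like(chunk))
        try:
            loss.backward(retain_graph=True)  # keep the shared graph's saved tensors alive
            outcomes.append("ok")
        except RuntimeError:
            outcomes.append("error")
        if step_between_chunks:
            optimizer.step()  # updates the weights in place, invalidating the retained graph
    return outcomes


no_step = run(step_between_chunks=False)
with_step = run(step_between_chunks=True)
print(no_step)    # ['ok', 'ok']
print(with_step)  # ['ok', 'error']
```

So retain_graph=True only removes the "second time" error; with an optimizer step between minibatches, the stored predictions themselves are stale.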

Actually, I don’t see why I should specify retain_graph=True: I’m only using chunks of this tensor, and each chunk is used only once. And, as I said, the next-state value has no history; I use it as a label.
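Each chunk is used once, but the chunks are not independent graphs: torch.split returns views of one tensor, so every chunk hangs off the same split node and, through whatever op built preds, off the forward graph of every stored prediction. Backward on the first chunk therefore walks, and frees, all of those graphs. The sketch below reproduces this, assuming preds was built with torch.stack (the question does not show that line) and using a hypothetical critic with random states. It takes no optimizer step, so it isolates the graph-sharing issue only: stacking each chunk from its own slice of the list keeps the graphs separate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical critic and states; they only stand in for the objects in the question.
value = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
states = torch.randn(64, 4)


def backward_per_chunk(chunks):
    """Call backward once per chunk (no retain_graph) and record which calls succeed."""
    outcomes = []
    for chunk in chunks:
        loss = F.mse_loss(chunk, torch.zeros_like(chunk))
        try:
            loss.backward()
            outcomes.append("ok")
        except RuntimeError:
            outcomes.append("error")
    return outcomes


# Case 1: stack every prediction into one tensor, then split it (as in the question).
v_mem = [value(s) for s in states]  # one small graph per prediction, shape (1,) each
shared_outcomes = backward_per_chunk(torch.split(torch.stack(v_mem), 32))

# Case 2: fresh predictions, but each chunk is stacked from its own slice of the list.
v_mem = [value(s) for s in states]
separate_chunks = [torch.stack(v_mem[i:i + 32]) for i in range(0, len(v_mem), 32)]
separate_outcomes = backward_per_chunk(separate_chunks)

print(shared_outcomes)    # ['ok', 'error']
print(separate_outcomes)  # ['ok', 'ok']
```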

How are you creating that list exactly?

In each episode, after each action, once I receive the new state, the reward and the termination flag, I run the new state through the value network. Basically, here’s what I do:

mask = 0. if done else 1.
new_state_value = value(new_state).detach()*mask
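A sketch of the usual way around both errors: keep the states rather than the predictions and redo the forward pass for each minibatch, so every backward call gets its own fresh graph built with the current weights. It assumes the visited states are stored in a list, hypothetically named s_mem, next to v_n_mem filled as in the snippet above; value, optimizer and value_update are illustrative names too, and gamma=0.99 matches the discount used in the question.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def value_update(value, optimizer, s_mem, v_n_mem, retours, batch_size=32, gamma=0.99):
    """Minibatch critic update that recomputes the predictions inside the loop.

    s_mem   : list of state tensors (hypothetical; the question stores predictions instead)
    v_n_mem : list of detached, masked next-state values
    retours : tensor of returns; its first column is used, as in the question
    """
    states = torch.stack(s_mem)                        # (N, state_dim), no graph attached
    next_values = torch.stack(v_n_mem).reshape(-1, 1)  # (N, 1), no history
    targets = retours[:, 0].reshape(-1, 1) + gamma * next_values

    n = len(s_mem)
    losses = []
    for b in range(math.ceil(n / batch_size)):
        lo, hi = b * batch_size, min((b + 1) * batch_size, n)
        batch_pred = value(states[lo:hi])              # fresh forward pass, fresh graph
        value_loss = F.mse_loss(batch_pred, targets[lo:hi])
        optimizer.zero_grad()
        value_loss.backward()                          # frees only this minibatch's graph
        optimizer.step()
        losses.append(value_loss.item())
    return losses


# Usage with a hypothetical critic and random data (70 transitions, so 3 minibatches).
torch.manual_seed(0)
value = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(value.parameters(), lr=0.01)
s_mem = [torch.randn(4) for _ in range(70)]
with torch.no_grad():
    v_n_mem = [value(torch.randn(4)) for _ in range(70)]
retours = torch.randn(70, 1)

losses = value_update(value, optimizer, s_mem, v_n_mem, retours)
print(len(losses))  # 3
```

Recomputing costs one extra forward pass per minibatch, but it is what lets the optimizer step safely between minibatches.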