I’m using PyTorch for reinforcement learning and I have a question concerning the usage of torch.split. After each episode, I collect a set of state/action/reward tuples. Then, in the critic optimization, I want to split this dataset into minibatches. Here’s what I’m doing:
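Just to state my understanding of torch.split: it chunks a tensor along dim 0 into pieces of the given size, with a smaller final chunk if the length isn’t a multiple. A quick check (toy tensor, not my actual data):

```python
import torch

# torch.split chunks along dim 0; the last chunk is smaller
# when the length is not a multiple of the chunk size.
preds = torch.arange(10.0)
chunks = torch.split(preds, 3)
sizes = [len(c) for c in chunks]  # [3, 3, 3, 1]
```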
```python
# VALUE UPDATE
batch_size = 32
nb_batchs = int(len(self.v_mem) / batch_size) + 1
preds = torch.cat(self.v_mem)  # v_mem holds all the state value predictions
chunks = torch.split(preds, batch_size)
for b in range(nb_batchs):
    current_min = b * batch_size
    current_max = np.min([(b + 1) * batch_size, len(self.v_mem)])
    # v_n_mem is a list holding predictions for the next state value.
    # Each entry is a tensor, but without history (I used .detach())
    selected = self.v_n_mem[current_min:current_max]
    batch_next_estim_tensor = torch.tensor(selected).float()
    batch_target = (retours[current_min:current_max, 0].reshape(-1, 1)
                    + 0.99 * batch_next_estim_tensor.reshape(-1, 1))
    batch_pred = chunks[b]
    value_loss = F.mse_loss(batch_pred, batch_target.detach())
    self.value.maj(value_loss)
```
This loop runs fine for the first iteration. On the second iteration, I get an error saying that I’m trying to backward through the graph a second time. I don’t see why. Could someone help?
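To rule out the rest of my setup, I can reproduce the same error with this stripped-down sketch (the toy weight `w` is just a stand-in for my value network’s parameters):

```python
import torch
import torch.nn.functional as F

# Same pattern as my training loop: one cat over all predictions,
# then split into chunks and one backward() per chunk.
w = torch.randn(1, requires_grad=True)
preds = torch.cat([w * x for x in torch.arange(4.0)])  # single shared graph
chunks = torch.split(preds, 2)
err = None
for chunk in chunks:
    loss = F.mse_loss(chunk, torch.zeros_like(chunk))
    try:
        loss.backward()  # raises on the second chunk
    except RuntimeError as e:
        err = str(e)
print(err)
```

So the error clearly comes from this cat/split pattern itself, not from the rest of my agent code.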