Hello,
I’m using PyTorch for reinforcement learning and I have a question about the usage of torch.split. After each episode I collect a set of state/action/reward tuples. During the critic optimization, I want to split this dataset into minibatches. Here’s what I’m doing:
# VALUE UPDATE
batch_size = 32
preds = torch.cat(self.v_mem)  # v_mem holds all the state-value predictions
chunks = torch.split(preds, batch_size)
nb_batchs = len(chunks)  # (int(len(self.v_mem)/batch_size)+1 overcounts when the length is a multiple of batch_size)
for b in range(nb_batchs):
    current_min = b * batch_size
    current_max = min((b + 1) * batch_size, len(self.v_mem))
    # v_n_mem is a list of next-state value predictions; each entry is a
    # tensor without history (I used .detach())
    selected = self.v_n_mem[current_min:current_max]
    batch_next_estim_tensor = torch.tensor(selected).float()
    # TD target: reward + discounted next-state value estimate
    batch_target = retours[current_min:current_max, 0].reshape(-1, 1) \
                   + 0.99 * batch_next_estim_tensor.reshape(-1, 1)
    batch_pred = chunks[b]
    value_loss = F.mse_loss(batch_pred, batch_target.detach())
    self.value.maj(value_loss)  # maj() runs backward() and an optimizer step
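(For reference, here is a tiny standalone sketch of what I expect torch.split to do with the concatenated predictions; toy sizes, not my real data.)

```python
import torch

# Stand-in for the concatenated predictions: 10 values, chunk size 4
preds = torch.arange(10, dtype=torch.float32)
chunks = torch.split(preds, 4)  # the last chunk holds the remainder
print([len(c) for c in chunks])  # -> [4, 4, 2]
```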
The loop runs fine for the first iteration, but afterwards I get an error saying that I’m trying to backprop through the graph a second time. I don’t see why. Could someone help?
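In case it helps, here is a small standalone script that reproduces the same error message outside my training code (my assumption is that the mechanism is the same, since all my chunks come from a single torch.cat/torch.split over one graph, but I’m not sure):

```python
import torch

x = torch.randn(4, requires_grad=True)
preds = x * x                    # one autograd graph shared by all elements
chunks = torch.split(preds, 2)   # the chunks still share that graph
chunks[0].sum().backward()       # first backward frees the graph's saved tensors
try:
    chunks[1].sum().backward()   # second backward through the same graph
except RuntimeError as e:
    print(e)                     # "Trying to backward through the graph a second time ..."
```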
Thanks!