Loss.backward() crashes in RL


I'm facing an issue with a reinforcement learning script I wrote, partially inspired by this tutorial: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html. There the author has one function for selecting the action and another for the backprop, which makes sense because in between we have to perform the selected action and observe the reward. However, in the tutorial's optimization function the forward pass is run again before the backprop, instead of reusing the network output that was computed when the action was performed. I cannot do the same, because I am using a Bayesian Neural Network for the exploration-exploitation policy: if I run the forward pass again, the output, and therefore the selected action, could be very different.
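For reference, the tutorial's pattern can be sketched roughly as follows: action selection runs under `torch.no_grad()` and only the state and action are kept, then the optimization step re-runs the network on the stored states and picks out the Q-value of each stored action with `gather`. As an assumption for illustration only, `nn.Dropout` kept in training mode stands in for the Bayesian network, which shows why the recomputed value need not match the one that drove the action.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_actions = 2
# Stand-in for the Bayesian network (assumption): dropout stays active,
# so every forward pass samples a different sub-network.
policy_net = nn.Sequential(
    nn.Linear(4, 32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, n_actions)
)
policy_net.train()

states = torch.randn(8, 4)

# Action selection (tutorial style): no graph is kept, only (state, action).
with torch.no_grad():
    q_at_selection = policy_net(states)
actions = q_at_selection.argmax(dim=1, keepdim=True)  # shape (8, 1)

# Optimization (tutorial style): a fresh forward pass on the stored states,
# then gather the Q-value of the action that was actually taken.
q_recomputed = policy_net(states).gather(1, actions)  # shape (8, 1), has grad_fn

# With a stochastic network the recomputed values generally differ from
# the ones that were used to choose the actions.
values_match = torch.allclose(q_recomputed, q_at_selection.gather(1, actions))
print(values_match)
```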
If I save the predicted value in the buffer instead, loss.backward() crashes unless I set retain_graph=True. With retain_graph=True the training converges, but it is very slow. I don't understand this behavior, because I am not calling loss.backward() more than once on the same loss, which is the cause other posts suggested. Without retain_graph I get this error:

RuntimeError: ntensor >= 3 ASSERT FAILED at /opt/conda/conda-bld/pytorch-cpu_1556653114183/work/build/aten/src/ATen/native/cpu/IndexKernel.cpp.AVX.cpp:51, please report a bug to PyTorch.
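One plausible explanation, assuming each buffer entry holds the tensor returned by the network (so it still carries its autograd graph) and that the same transition can be sampled into more than one minibatch: the first backward through an entry frees that entry's saved tensors, and the next minibatch that contains it has nothing left to backpropagate through. The sketch below reproduces this with a hypothetical list-based buffer and hypothetical helpers `fill_buffer` and `update`. On recent PyTorch versions this usually surfaces as "Trying to backward through the graph a second time"; the assertion quoted above comes from an older CPU build, so it is not certain it has the same root cause.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))

def fill_buffer(n):
    # Hypothetical buffer: stores the network *output* (with its graph),
    # as produced at action-selection time, next to the observed reward.
    buffer = []
    for _ in range(n):
        value = net(torch.randn(1, 4)).view(1)       # logit, has grad_fn
        reward = torch.randint(0, 2, (1,)).float()   # 0/1 target
        buffer.append((value, reward))
    return buffer

def update(buffer, indices, retain_graph=False):
    values = torch.stack([buffer[i][0] for i in indices]).view(-1)
    rewards = torch.stack([buffer[i][1] for i in indices]).view(-1)
    loss = F.binary_cross_entropy_with_logits(values, rewards)
    net.zero_grad()
    loss.backward(retain_graph=retain_graph)

# Default behaviour: the first backward frees the saved tensors of entries
# 0 and 1, so sampling entry 1 again in the next minibatch fails.
buffer = fill_buffer(3)
update(buffer, [0, 1])
try:
    update(buffer, [1, 2])
    reuse_failed = False
except RuntimeError as err:
    reuse_failed = True
    print("second update failed:", err)

# retain_graph=True keeps every stored graph alive, so the reuse succeeds,
# at the cost of holding all of those graphs in memory.
buffer = fill_buffer(3)
update(buffer, [0, 1], retain_graph=True)
update(buffer, [1, 2], retain_graph=True)
reuse_with_retain_ok = True
```

Keeping every stored graph alive for as long as its entry stays in the buffer would also be consistent with the slowdown described above, though that is an inference rather than something measured here.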

If I instead run the forward pass again on the whole minibatch, backward() works, but the training is not effective: the loss is then computed on prediction/ground-truth pairs that may not be correlated, since the new prediction is not necessarily the one that selected the action.
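One way to rebuild a graph for exactly the prediction that selected the action is to make the stochastic forward pass replayable. This is a sketch under an assumption: all of the network's randomness is drawn from PyTorch's default CPU generator, which holds for `nn.Dropout` (used here as a stand-in for the Bayesian layers) and would also hold for a layer that samples its weights with `torch.randn` on the CPU. The generator state is saved with `torch.get_rng_state()` before the forward pass at action-selection time, stored in the buffer instead of the output tensor, and restored with `torch.set_rng_state()` before recomputing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for the Bayesian network (assumption): dropout kept active.
net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, 1))
net.train()

state = torch.randn(1, 4)

# Action selection: snapshot the RNG state, then run the stochastic forward pass.
# Only plain data (state, RNG snapshot, later the reward) needs to be stored.
rng_snapshot = torch.get_rng_state()
value_at_selection = net(state).detach()

# Other forward passes consume random numbers in the meantime.
_ = net(torch.randn(5, 4))

# Optimization: replay the same noise sample to rebuild a fresh graph.
current_rng = torch.get_rng_state()
torch.set_rng_state(rng_snapshot)
value_replayed = net(state)          # same dropout mask as at selection time
torch.set_rng_state(current_rng)     # restore the global random stream
```

Because no optimizer step happens between the two passes here, the replayed value equals the stored one; after an update, the replay gives the updated network's prediction under the same noise sample, and each update builds and frees its own graph. The replay should use the same input shape as at selection time (one state at a time here), since a batched pass draws a differently sized mask.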

This is what I am doing, with Transition being a namedtuple like in the tutorial:

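```python
import torch
import torch.nn.functional as F

# memory, node, minibatch and Transition are defined elsewhere in the script
transitions = memory[node].sample(minibatch)
batch = Transition(*zip(*transitions))      # list of Transitions -> Transition of tuples
reward_batch = torch.stack(batch.reward)    # targets
value_batch = torch.stack(batch.value)      # outputs saved at action-selection time (graphs attached)
value_batch = value_batch.view(value_batch.shape[0])  # one logit per transition
loss = F.binary_cross_entropy_with_logits(value_batch, reward_batch)
```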
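To run the snippet above in isolation, here is a sketch with a plain list standing in for `memory[node]` and `random.sample` standing in for its `sample` method. The `Transition` field names `value` and `reward` are assumptions based on the attributes the snippet uses, and so are the shapes: each stored value is taken to be the shape-`(1,)` output of a one-unit linear layer and each reward a 0-dim float tensor, which is what makes the `view` necessary for the two batches to line up.

```python
import random
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical field names: the post only shows that `value` and `reward` exist.
Transition = namedtuple("Transition", ("value", "reward"))

torch.manual_seed(0)
random.seed(0)
net = nn.Linear(4, 1)

# Hypothetical buffer: a plain list stands in for memory[node].
memory = []
for _ in range(10):
    value = net(torch.randn(4))                          # shape (1,), keeps its graph
    reward = torch.tensor(float(random.random() > 0.5))  # 0-dim float target
    memory.append(Transition(value, reward))

minibatch = 4
transitions = random.sample(memory, minibatch)  # no entry appears twice in one minibatch
batch = Transition(*zip(*transitions))          # list of Transitions -> Transition of tuples
reward_batch = torch.stack(batch.reward)        # shape (4,)
value_batch = torch.stack(batch.value)          # shape (4, 1)
value_batch = value_batch.view(value_batch.shape[0])  # shape (4,), matches the targets
loss = F.binary_cross_entropy_with_logits(value_batch, reward_batch)
loss.backward()
```

A single call like this succeeds; the failure only appears when a later minibatch samples an entry whose graph an earlier backward has already freed.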