Backward error on KL divergence

Hi,

I think you should rebuild the computation graph in each iteration. By default, calling .backward() frees the intermediate buffers saved in the graph that was built before, so a second backward pass through that same graph fails. See more details in this thread.
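A minimal sketch of the difference, assuming a toy setup that is not from the original question: trainable `logits` are fitted to a fixed `target` distribution with `F.kl_div`. `stale_graph_raises` builds `log_softmax` once outside the loop, so the second `.backward()` reaches a node whose saved tensors were already freed. `train` recomputes everything from the leaf tensor inside the loop, so each iteration gets a fresh graph. The function names, shapes, and learning rate are hypothetical.

```python
import torch
import torch.nn.functional as F


def stale_graph_raises():
    """Reuse a graph built once outside the loop (the failing pattern)."""
    torch.manual_seed(0)
    logits = torch.randn(4, 5, requires_grad=True)
    target = torch.softmax(torch.randn(4, 5), dim=1)
    log_probs = F.log_softmax(logits, dim=1)  # built ONCE, outside the loop
    try:
        for _ in range(2):
            loss = F.kl_div(log_probs, target, reduction="batchmean")
            # The first call frees log_softmax's saved buffers; the second fails.
            loss.backward()
    except RuntimeError:
        return True
    return False


def train(steps=100, lr=0.5):
    """Rebuild the graph on every iteration (the working pattern)."""
    torch.manual_seed(0)
    logits = torch.randn(4, 5, requires_grad=True)    # hypothetical parameters
    target = torch.softmax(torch.randn(4, 5), dim=1)  # hypothetical fixed target
    optimizer = torch.optim.SGD([logits], lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        # Recompute the forward pass from the leaf tensor inside the loop,
        # so this iteration's .backward() runs on a freshly built graph.
        log_probs = F.log_softmax(logits, dim=1)
        loss = F.kl_div(log_probs, target, reduction="batchmean")
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    print("stale graph raises:", stale_graph_raises())
    losses = train()
    print(f"KL loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Passing `retain_graph=True` to `.backward()` would also silence the error, because it keeps the buffers alive. Recomputing the forward pass each iteration is usually what you want, though, since the loss should reflect the updated parameters.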
