Detach hidden state of model

Hello,
I’m working on a model to solve physical PDEs, but during training the time per epoch keeps increasing from one epoch to the next, and I don’t understand why. The model is quite unusual (not a classical one): it’s built with pytorch-geometric, more specifically with the ARMA layer.
I read in some discussions here on the forum that deleting hidden states can reduce memory usage, but I don’t know how to apply that in my case, since my model is composed of two ARMA layers and a forward function that implements the PDEs.
Thanks

The issue you describe is often caused by a computation graph that grows with every iteration, which increases not only the computation time but also the memory usage.
This can happen if you keep the computation graph alive with backward(retain_graph=True), which prevents the graph from being freed after the backward pass.
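Here is a minimal sketch of that pattern. It assumes a plain nn.Linear as a hypothetical stand-in for the ARMA-based model (so pytorch-geometric isn't needed), and count_graph_nodes is a hypothetical helper that just walks grad_fn.next_functions to measure the size of the graph. The hidden tensor is carried over without detaching, so each step's graph stays attached to all previous ones:

```python
import torch
import torch.nn as nn


def count_graph_nodes(t: torch.Tensor) -> int:
    """Count the distinct autograd nodes reachable from t.grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


torch.manual_seed(0)
# Hypothetical stand-in for the ARMA-based model.
model = nn.Linear(4, 4)

hidden = torch.zeros(1, 4)
graph_sizes = []
for step in range(5):
    model.zero_grad()
    # hidden is NOT detached, so this step's graph hangs off all previous steps
    hidden = torch.tanh(model(hidden))
    loss = hidden.pow(2).mean()
    # retain_graph=True keeps the earlier steps' graph alive for the next call
    loss.backward(retain_graph=True)
    graph_sizes.append(count_graph_nodes(loss))

print(graph_sizes)  # increases on every step
```

Without retain_graph=True, the second backward() call in this loop would raise a RuntimeError, because the graph of the first step would already have been freed. With it, the node count (and the work done by each backward() call) grows with every iteration.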
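Often this issue is caused by reusing the same hidden tensors without detaching them, so that each new iteration’s graph is attached to all the previous ones. While this might fit your use case, you would have to check “how far” into past iterations the gradients should be computed, and detach the hidden state at that point.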
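One common way to bound “how far” the gradients go is truncated backpropagation through time: detach the hidden state every k steps and call backward() once per chunk of k steps. A minimal sketch, again using nn.Linear as a hypothetical stand-in for the model; k, the loss, and the count_graph_nodes helper are assumptions for illustration only:

```python
import torch
import torch.nn as nn


def count_graph_nodes(t: torch.Tensor) -> int:
    """Count the distinct autograd nodes reachable from t.grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


torch.manual_seed(0)
# Hypothetical stand-in for the ARMA-based model.
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

k = 3  # assumption: backpropagate through at most k consecutive steps
hidden = torch.zeros(1, 4)
graph_sizes = []
for chunk in range(4):
    # cut the graph: no gradient flows into steps before this point,
    # but the hidden *values* are still carried over
    hidden = hidden.detach()
    optimizer.zero_grad()
    loss = torch.zeros(())
    for _ in range(k):
        hidden = torch.tanh(model(hidden))
        loss = loss + hidden.pow(2).mean()
    loss.backward()  # plain backward, no retain_graph needed
    optimizer.step()
    graph_sizes.append(count_graph_nodes(loss))

print(graph_sizes)  # same value for every chunk
```

Because of the detach at the start of each chunk, the graph size stays constant and retain_graph=True is no longer needed. Setting k = 1 corresponds to detaching the hidden state at every iteration.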