Python crashes while integrating differential equation

I understand that requires_grad=True means history is tracked, but even a repeated addition on a 3x3 tensor crashes Python, as it runs out of memory. The code inside the loop represents a forward Euler integration of the differential equation dx/dt = x. I’m using PyTorch 0.3.0 on Windows 10.

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor(range(9)).view(3,3), requires_grad=True) 
time_step = 0.002
for i in range(10000):
    x = x + x*time_step    
y = torch.sum(x)  # just an arbitrary loss function (in fact none of the code from this line on is needed to recreate the crash)
y.backward()
print(x.grad)  

I do understand the advice laid out here, but you can easily imagine the need for code that looks like the above for systems of ODEs.

How should I handle this? Is truncation the only way?

Isn’t this equivalent to the following?

x = x * (1 + time_step) ** 10000

That would probably solve the memory issue, because the computation graph being built would be much smaller: a single multiplication instead of 10000 additions and multiplications.
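A quick way to sanity-check that equivalence is to compare values and gradients from the unrolled loop against the closed form. The sketch below assumes a recent PyTorch release (plain tensors with requires_grad instead of Variable), uses float64 to limit rounding error, and runs fewer steps than the original post so the check stays cheap. The names n_steps, x_loop and x_closed are only illustrative.

```python
import torch

time_step = 0.002
n_steps = 1000  # assumption: fewer steps than the original post, to keep the check cheap

# Unrolled loop: one multiplication and one addition are recorded per iteration.
x_loop = torch.arange(9, dtype=torch.float64).view(3, 3).requires_grad_()
x = x_loop
for _ in range(n_steps):
    x = x + x * time_step
x.sum().backward()

# Closed form: a single multiplication by a precomputed Python float.
x_closed = torch.arange(9, dtype=torch.float64).view(3, 3).requires_grad_()
y = x_closed * (1 + time_step) ** n_steps
y.sum().backward()

same_value = torch.allclose(x.detach(), y.detach())
same_grad = torch.allclose(x_loop.grad, x_closed.grad)
print(same_value, same_grad)
```

This only works because the update is linear with a constant factor; for a general right-hand side there is usually no such closed form.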

Yes, for this particular minimal working example. But my actual system of ODEs is much more complicated than that. How do we handle those?

I don’t know enough about ODEs to suggest specific optimizations right now, but the main problem (as you’ve noticed) is that every operation in the loop adds to the computation graph. Each addition and multiplication is recorded as an operation in the graph, so after 10000 iterations the graph has become very large. Do you notice your memory usage steadily growing with each iteration of your for loop?
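To see the growth directly, one option is to count the autograd nodes reachable from the result for different loop lengths. This is a rough sketch assuming a recent PyTorch release, where grad_fn and next_functions are available on tensors. The helpers count_graph_nodes and euler_steps are hypothetical names made up for this illustration.

```python
import torch

def count_graph_nodes(output):
    """Count the distinct autograd nodes reachable from output.grad_fn."""
    seen = set()
    stack = [output.grad_fn]
    while stack:  # iterative walk, so deep graphs do not hit the recursion limit
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

def euler_steps(n_steps, time_step=0.002):
    # Same update as in the original post, for a configurable number of steps.
    x = torch.arange(9, dtype=torch.float32).view(3, 3).requires_grad_()
    for _ in range(n_steps):
        x = x + x * time_step
    return x

counts = {n: count_graph_nodes(euler_steps(n)) for n in (100, 200, 400)}
print(counts)
```

The node count should grow with the number of iterations, which is why memory use climbs steadily inside the loop.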

We’re working on new ways to trade compute for memory (see this pull request) that will probably help in this case.
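For reference, later PyTorch releases ship gradient checkpointing in torch.utils.checkpoint, which is one way to trade compute for memory. Whether that is what the pull request above turned into is an assumption here, not something this thread confirms. The sketch below splits the 10000 Euler steps into chunks; euler_chunk, n_chunks and steps_per_chunk are hypothetical names, and use_reentrant=False assumes a reasonably recent version.

```python
import torch
from torch.utils.checkpoint import checkpoint

def euler_chunk(x, time_step, n_steps):
    # Hypothetical helper: n_steps forward Euler updates for dx/dt = x.
    for _ in range(n_steps):
        x = x + x * time_step
    return x

time_step = 0.002
n_chunks, steps_per_chunk = 100, 100  # 10000 steps in total, as in the original post

x0 = torch.arange(9, dtype=torch.float64).view(3, 3).requires_grad_()
x = x0
for _ in range(n_chunks):
    # Only each chunk's input is stored; the steps inside are recomputed during backward.
    x = checkpoint(euler_chunk, x, time_step, steps_per_chunk, use_reentrant=False)

x.sum().backward()

# For dx/dt = x, every entry of the gradient should equal (1 + time_step) ** 10000.
expected = torch.full_like(x0, (1 + time_step) ** (n_chunks * steps_per_chunk))
print(torch.allclose(x0.grad, expected))
```

The chunk size is a tuning knob: larger chunks store fewer intermediate tensors between chunks but need more memory while a single chunk is being recomputed.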