I’m currently implementing a complex custom model which looks like:

```
for step in range(18):
    # data calculation with big matrices
    x = lstmcell(data)
    # data calculation with x to get a new_x
    data = new_x
```
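Here is a minimal runnable sketch of that structure (the sizes, the `* 2.0` pre-processing, and the `new_x` formula are placeholders; the real calculations are much bigger):

```python
import torch
import torch.nn as nn

# Toy sizes; the real matrices are much larger.
hidden_size = 8
cell = nn.LSTMCell(hidden_size, hidden_size)

data = torch.randn(4, hidden_size)      # batch of 4
h = torch.zeros(4, hidden_size)
c = torch.zeros(4, hidden_size)

for step in range(18):
    # stand-in for "data calculation with big matrices"
    pre = data * 2.0
    h, c = cell(pre, (h, c))
    # stand-in for "data calculation with x to get a new_x"
    new_x = h + pre
    data = new_x                         # each step feeds the next one

loss = data.sum()
loss.backward()                          # the graph spans all 18 steps
```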
My problem is that my GPU (and my CPU, for that matter) runs out of memory: the gradient graph is built over the data calculation 18 times, so by the end a very large graph is stored on my GPU (memory usage over 20 GB). Each iteration needs the data from the previous one.
So, do you guys know any tricks to reduce memory usage with this kind of model structure? I tried wrapping the parts where I do the data calculation in `torch.no_grad()`, and that does cut the memory, but then my network’s loss doesn’t decrease anymore, and I think that’s because of the `no_grad` block. Do I really need to compute the gradient of the calculation part?
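For reference, my `no_grad` attempt looks roughly like this (same toy sketch as above, placeholder sizes and formulas). Wrapping the pre-processing in `torch.no_grad()` means those ops are never recorded, so the backward pass stops at each step’s input and gradients no longer flow back through the chain of earlier steps:

```python
import torch
import torch.nn as nn

hidden_size = 8
cell = nn.LSTMCell(hidden_size, hidden_size)

data = torch.randn(4, hidden_size)
h = torch.zeros(4, hidden_size)
c = torch.zeros(4, hidden_size)

for step in range(18):
    with torch.no_grad():      # saves memory, but...
        pre = data * 2.0       # ...this op is excluded from the graph,
                               # so gradients stop here at every step
    h, c = cell(pre, (h, c))
    data = h + pre             # still requires grad through h only

loss = data.sum()
loss.backward()                # only reaches the LSTMCell of each step,
                               # not the data chain between steps
```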