RNN CUDA out of memory

I am training a naive RNN on 1D time series data with the training loop below. Since the full batch x seq_len x entire_time_steps tensor does not fit on the GPU, I transfer a batch x seq_len slice to the GPU for every time step. However, the memory keeps growing with every time step, as the numbers below show. I tried the following approaches.

Approach 1: Allocating inp only once, outside the time loop, and copying each batch x seq_len slice into it. This failed with an error saying it cannot copy from torch.FloatTensor to torch.cuda.FloatTensor.
Approach 2: Calling torch.cuda.empty_cache() every 50 iterations. The memory still keeps growing.

For 1D time series data, how should I transfer the data to CUDA for every time step? What is a good way to run this training loop if the entire batch x seq_len x entire_time_steps tensor is too large to fit in CUDA memory?


Memory growth (time step, torch.cuda.max_memory_cached() in bytes):
0 1048576
1 43122688
2 84017152
3 124911616
4 165806080
5 207749120
6 248643584
7 289538048
8 330432512
9 371326976
10 412221440
11 454164480
12 495058944
13 535953408
14 576847872
15 617742336
16 658636800
17 699531264

393 16136667136 <-- out of memory at this time step.

Code:

for epoch in range(1, NUM_EPOCHS + 1):
    rand_idxs = np.random.permutation(X_train.shape[0])
    batches = np.array_split(rand_idxs, NUM_BATCHES)
    print("Epoch: {:4}".format(epoch))    
    for j,batch in enumerate(batches,1): 
        current_batch_X = (X_train_torch[batch,:,:])                
        hidden = torch.zeros((NUM_LAYERS, current_batch_X.shape[0],HIDDEN_SIZE)).cuda()
        # over all time steps
        for i in range(0, current_batch_X.shape[2]):
            inp = current_batch_X[:,:,i].reshape((current_batch_X.shape[0],
            print(i, torch.cuda.max_memory_cached())
            _,hidden = rnn(inp, hidden)
            if i%50 == 0:
                torch.cuda.empty_cache()
        current_batch_y = (y_train_torch[batch]).cuda()
        loss = criterion(output, current_batch_y)

torch.cuda.empty_cache() will not avoid the out-of-memory issue. It only releases cached blocks that are no longer in use, so it might just slow down your code, as PyTorch would need to reallocate the device memory.

I guess your memory usage grows because the hidden state carries the computation graph of every previous time step, so all intermediates stay in memory until backward() is called and frees them.
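The growth can be made visible without a GPU by counting the autograd nodes reachable from the hidden state. The following is a minimal sketch, not the code from the question: the helper names graph_size and run_steps, the layer sizes, and the random inputs are all hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn


def graph_size(tensor):
    """Count the autograd nodes reachable from tensor.grad_fn."""
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # next_functions holds (node, input_index) pairs for the parent nodes
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


def run_steps(num_steps, detach):
    """Feed num_steps random inputs through an RNN, one time step at a time."""
    torch.manual_seed(0)
    # Hypothetical sizes: batch=4, input_size=8, hidden_size=16
    rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
    hidden = torch.zeros(1, 4, 16)
    for _ in range(num_steps):
        if detach:
            # Cut the link to the graph built in the previous steps
            hidden = hidden.detach()
        _, hidden = rnn(torch.randn(4, 1, 8), hidden)
    return graph_size(hidden)


for steps in (5, 20):
    print(steps, run_steps(steps, detach=False), run_steps(steps, detach=True))
```

Without detach() the node count increases with the number of steps, while with detach() it stays the same regardless of how many steps were run.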
You could either lower the number of time steps or call backward() after each iteration (or every few iterations) and detach() the hidden state before the next step.
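A minimal sketch of the second option (often called truncated backpropagation through time) could look like the code below. It rests on several assumptions that are not in the original post: the sizes, the random data, the linear output layer head, the optimizer, the chunk length CHUNK, and the choice to compute the loss against the same target at the end of every chunk are all hypothetical. The data stays on the CPU and only one batch x seq_len slice is moved to the device per time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical sizes, chosen only for illustration
BATCH, SEQ_LEN, TIME_STEPS = 4, 8, 100
HIDDEN_SIZE, NUM_LAYERS, CHUNK = 16, 1, 10

rnn = nn.RNN(input_size=SEQ_LEN, hidden_size=HIDDEN_SIZE,
             num_layers=NUM_LAYERS, batch_first=True).to(device)
head = nn.Linear(HIDDEN_SIZE, 1).to(device)  # hypothetical output layer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(
    list(rnn.parameters()) + list(head.parameters()), lr=1e-2)

# Random stand-in data, kept on the CPU: batch x seq_len x time_steps
X = torch.randn(BATCH, SEQ_LEN, TIME_STEPS)
y = torch.randn(BATCH, 1)


def train_batch(X, y):
    """Run one batch, calling backward() every CHUNK time steps."""
    hidden = torch.zeros(NUM_LAYERS, X.shape[0], HIDDEN_SIZE, device=device)
    y = y.to(device)
    losses = []
    for start in range(0, X.shape[2], CHUNK):
        # Detach so the new chunk does not reference the previous graph
        hidden = hidden.detach()
        for i in range(start, min(start + CHUNK, X.shape[2])):
            # Move only the current slice to the device: (batch, 1, seq_len)
            inp = X[:, :, i].unsqueeze(1).to(device)
            output, hidden = rnn(inp, hidden)
        # Assumption: the loss is evaluated at the end of every chunk
        loss = criterion(head(output[:, -1, :]), y)
        optimizer.zero_grad()
        loss.backward()  # frees the intermediates of this chunk
        optimizer.step()
        losses.append(loss.item())
    return losses


losses = train_batch(X, y)
print(len(losses))
```

With this structure the graph held in memory never spans more than CHUNK time steps, at the cost of gradients not flowing further back than one chunk.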
Let me know if this works for you.