I have read over ten different reports of a similar problem with exploding RAM on a CPU, and none of the suggested fixes have worked.
I can run my forward pass using just the Encoder network without any problems, so I know my DataLoader and Encoder are fine. If I use the Decoder just for forward passes, without even storing or computing any losses, the RAM explodes.
So what is it about my decoder that causes this? I tried to follow the attention-based decoder in the official PyTorch tutorial (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) as closely as possible. The only difference in my code is that my Encoder processes the whole minibatch first rather than a single example, so I have a double for loop (looping over the batch elements and then over each of their timesteps).
Here is my code to run the forward passes with a randomly created latent space (it normally comes from the Encoder):
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device("cpu")

epochs = 5
learning_rate = 0.01

# initialize the NN
decoder_net = DecoderNet().to(device)
dataloader = DataLoader(df, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=0, collate_fn=pad_and_sort_batch)

# iterate through epochs
for e in range(1, epochs + 1):
    # counter for the number of minibatches processed
    mini_batch_iters = 0
    # randomly generated latent space: minibatch size of 32, latent vectors of length 200
    latent = torch.randn((32, 200))
    # iterate through the data; the collate_fn returns the sequence lengths at index 1
    for _, batch in enumerate(dataloader):
        lengths = batch[1]
        # find the longest sequence in the batch
        max_l = int(torch.max(lengths))
        # holder for the outputs of the whole batch
        batch_outs = torch.zeros([BATCH_SIZE, max_l, VOCAB_SIZE],
                                 device=device, requires_grad=True)
        # iterate through the minibatch
        for length, b_ind in zip(lengths, range(BATCH_SIZE)):
            length = int(length)
            # holder variables, including the initial hidden states for the LSTMs
            seq_outs = torch.zeros([length, VOCAB_SIZE], device=device, requires_grad=True)
            prev_out = torch.zeros([VOCAB_SIZE], device=device, requires_grad=False)
            hidden = (torch.zeros([2, 1, CODE_LAYER_SIZE // 2], device=device, requires_grad=False),
                      torch.zeros([2, 1, CODE_LAYER_SIZE // 2], device=device, requires_grad=False))
            hidden_out = (torch.zeros([1, 1, VOCAB_SIZE], device=device, requires_grad=False),
                          torch.zeros([1, 1, VOCAB_SIZE], device=device, requires_grad=False))
            # for each element in the batch, run it through the 2 LSTMs one timestep at a time
            for t in range(length):
                prev_out, hidden, hidden_out = decoder_net(latent[b_ind, :], prev_out,
                                                           hidden, hidden_out)
                seq_outs[t, :] = prev_out
            # pad the sequence to max_l and add it to the batch data collector
            batch_outs[b_ind, :, :].add_(
                torch.cat((seq_outs,
                           torch.zeros([max_l - length, VOCAB_SIZE],
                                       device=device, requires_grad=False)), 0))
        mini_batch_iters += 1
        print('mini batch iters', mini_batch_iters)
After 30 minibatches my CPU RAM usage explodes from under 2 GB to over 13 GB, and it keeps growing roughly linearly.
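For what it's worth, this is roughly how I watch the memory grow; it is just a monitoring sketch using psutil (not part of the training code itself), printed once per minibatch:

import os
import psutil

def print_rss(tag):
    # print the resident set size (RSS) of the current process, in MB
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    print(f'{tag}: {rss_mb:.1f} MB')

# e.g. at the end of the minibatch loop:
# print_rss(f'after minibatch {mini_batch_iters}')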
Here is my simple DecoderNet:
import torch
import torch.nn as nn
import torch.nn.functional as F

BATCH_SIZE = 32
CODE_LAYER_SIZE = 200
VOCAB_SIZE = 21

class DecoderNet(nn.Module):
    def __init__(self):
        super(DecoderNet, self).__init__()
        # decode
        # self.dense1_predecode = nn.Linear(in_features=CODE_LAYER_SIZE, out_features=50)
        self.decoder = nn.LSTM(input_size=CODE_LAYER_SIZE + VOCAB_SIZE, bidirectional=False,
                               hidden_size=CODE_LAYER_SIZE // 2, num_layers=2)
        self.decoder_out = nn.LSTM(input_size=CODE_LAYER_SIZE // 2, bidirectional=False,
                                   hidden_size=VOCAB_SIZE, num_layers=1)
        # may want to do Bilinear times 2 if the encoder is bidirectional
        # self.dense1_dec = nn.Linear(in_features=CODE_LAYER_SIZE * 2, out_features=50)
        # self.dense2_dec = nn.Linear(in_features=50, out_features=20)

    def forward(self, latent_space, prev_out, hidden, hidden_out):
        # decoding: takes a single latent code for a single element of the batch;
        # prev_out is either the teacher-forcing token or the prediction from
        # the previous step
        prev_out, hidden = self.decoder(
            torch.cat((latent_space, prev_out), 0).view(1, 1, -1), hidden)
        prev_out, hidden_out = self.decoder_out(prev_out, hidden_out)
        prev_out = F.log_softmax(prev_out, dim=2)
        return prev_out.squeeze_(), hidden, hidden_out
I don't understand what is wrong. It must have something to do with how I store the outputs from each timestep and each minibatch entry? However, the official PyTorch tutorial (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) does something very similar.
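One thing I have been wondering about (just a sketch of what I mean, reusing the variables from the loop above): since these are pure forward passes, I could presumably wrap the timestep loop in torch.no_grad() so that no autograd graph is kept around, although that would not help for the real training run where I do need gradients:

# sketch only: pure-inference variant of the inner timestep loop
seq_outs = torch.zeros([length, VOCAB_SIZE], device=device)  # no requires_grad
with torch.no_grad():
    for t in range(length):
        prev_out, hidden, hidden_out = decoder_net(latent[b_ind, :], prev_out,
                                                   hidden, hidden_out)
        seq_outs[t, :] = prev_out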
UPDATE: While waiting for an answer, I ran the Decoder on the whole batch at once, without the for loops (this isn't what I ultimately want), and it worked without any memory issues. Clearly something is wrong with how I am using my for loops and storing intermediate computations.
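For reference, this is approximately what that batched test looked like (a rough sketch from memory, not the exact code: I called the two LSTMs directly on the whole padded batch, feeding zeros as the "previous output" channel since I was only testing forward passes):

# rough sketch of the batched forward pass that did NOT leak memory
seq_len = max_l
# repeat the latent code across all timesteps: (seq_len, batch, CODE_LAYER_SIZE)
lat = latent.unsqueeze(0).expand(seq_len, BATCH_SIZE, CODE_LAYER_SIZE)
# dummy "previous output" channel, all zeros
prev = torch.zeros(seq_len, BATCH_SIZE, VOCAB_SIZE, device=device)
inp = torch.cat((lat, prev), dim=2)  # (seq_len, batch, CODE_LAYER_SIZE + VOCAB_SIZE)
out, hidden = decoder_net.decoder(inp)           # 2-layer LSTM over the whole batch
out, hidden_out = decoder_net.decoder_out(out)   # output LSTM
batch_outs = F.log_softmax(out, dim=2)           # (seq_len, batch, VOCAB_SIZE)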
Thank you in advance for any help!