Optimizing CUDA memory pipeline for RNN

I’m working with RNNs for medium-sized data (fits on a single machine, probably won’t need multiple GPUs).

But I want to get the most performance out of my RNN with the GPU I have, so I’ve been testing with even smaller datasets to make sure I understand the principles behind moving memory around with PyTorch. I want to check my understanding and see what I’m missing. I have experience with graphics programming (OpenGL shaders, frame/vertex/pixel buffers and streaming, etc.) but not really with CUDA.

My GPU memory is going to be divided mainly between the model parameters and the batch. Let’s say I want my model to be big, so I use about 70% of the GPU memory for the model and optimizer parameters. Then the batch might take 10% of the memory.

While processing one batch I do this (a minimal code sketch follows the list):

  1. Access the data (stored in RAM as a torch.tensor). Each batch has size (seq_length+1, batch_size).
  2. Move the tensor to GPU, converting it to a torch.cuda.tensor with batch_tensor = batch_tensor.cuda()
  3. Create two variables from sliced chunks of the tensor: Variable(batch_tensor[:-1]) and Variable(batch_tensor[1:])
  4. Reset the model, run prediction, loss, backprop, and optimization.
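Putting steps 1–4 together, a simplified version of my loop body looks roughly like this (model, criterion and optimizer stand in for the real objects; the model maps a (seq_length, batch_size) LongTensor to per-step scores):

from torch.autograd import Variable

def train_step(batch_tensor, model, criterion, optimizer):
    # batch_tensor: (seq_length + 1, batch_size) LongTensor living in host RAM
    batch_tensor = batch_tensor.cuda()       # step 2: host -> device copy
    inputs = Variable(batch_tensor[:-1])     # step 3: slices of the GPU tensor
    targets = Variable(batch_tensor[1:])

    optimizer.zero_grad()                    # step 4: reset, predict, loss,
    output = model(inputs)                   #         backprop, optimize
    loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))
    loss.backward()
    optimizer.step()
    return loss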

Questions:

  1. If I can fit two batches in memory at the same time, can I be swapping one out while the other is still processing? How can I do this?
  2. Is the slicing operation slow/doing copying, or is it fast/more like reshaping?
  3. When does a torch.cuda.tensor actually get freed from the GPU? It looks like running a training loop steadily grows my GPU memory usage with every cuda() call, and it’s never freed up. Is there a way to generate an indefinitely long sequence with a fixed amount of GPU memory?
  4. Are there any other optimizations I can make?

For completeness, here is a test notebook on a small dataset (char-rnn/shakespeare) showing what I’ve got so far. I would also really appreciate any other general feedback about the code.

  1. Depends on what you mean by swapping out. You should be able to do fast asynchronous host-to-device and device-to-host copies if you use a background stream (torch.cuda.Stream) and a pinned buffer for staging (a rough sketch follows after this list). But if I were you I’d make sure it’s really worth it: you probably don’t want to spend too much time debugging synchronization issues just to save a few ms.
  2. Regular slicing is very fast. The returned tensor shares memory with the original one; only some metadata is modified.
  3. The memory is ready to be reused as soon as the tensor goes out of scope in Python (unless you have reference cycles in your objects). The GPU memory only grows because we’re caching CUDA allocations. cudaMalloc and cudaFree are slow and synchronize the CPU with the device, so it’s unacceptable to call them all the time. Our CUDA allocator keeps all the allocated blocks around, because it assumes you’re going to use them at the next iteration too. However, if you’re going to operate on inputs of different lengths, you might want to do a warm-up pass with the largest length you’re going to use throughout training (a sketch also follows below). This ensures that the allocated buffers will be able to contain any of the sequences. Otherwise the allocator has to do some compacting and allocate larger blocks once you use a sequence that doesn’t fit into the blocks you have already allocated.
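For point 1, roughly what I have in mind is something like the following (untested sketch; loader is whatever iterable yields your CPU batches, and on the PyTorch versions current at the time non_blocking was spelled async). The idea is that the host-to-device copy of batch i+1 runs on a side stream while batch i is being processed on the default stream:

import torch

copy_stream = torch.cuda.Stream()           # side stream used only for H2D copies

def prefetched(loader):
    # Yield GPU batches; the copy of batch i+1 overlaps the compute on batch i.
    prev_gpu, prev_pinned = None, None
    for cpu_batch in loader:                # loader yields CPU tensors
        pinned = cpu_batch.pin_memory()     # page-locked staging buffer
        with torch.cuda.stream(copy_stream):
            gpu = pinned.cuda(non_blocking=True)   # asynchronous copy
        if prev_gpu is not None:
            yield prev_gpu                  # caller trains on the previous batch here,
                                            # concurrently with the copy started above
        # before the default stream touches the new batch, wait for its copy to finish
        torch.cuda.current_stream().wait_stream(copy_stream)
        prev_gpu, prev_pinned = gpu, pinned # keep the pinned buffer alive until the copy is done
    if prev_gpu is not None:
        yield prev_gpu

Usage would just be for batch in prefetched(my_batches): train_step(batch, ...), and again, it’s only worth the synchronization headaches if the copies take a significant fraction of the step time.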
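For the warm-up in point 3, something along these lines should be enough (rough sketch, with model, criterion and the shapes standing in for your own). One throwaway forward/backward at the maximum sequence length makes the caching allocator reserve blocks large enough for every later, shorter batch:

import torch
from torch.autograd import Variable

def warm_up(model, criterion, max_seq_length, batch_size):
    # dummy batch at the largest length seen during training
    dummy = torch.zeros(max_seq_length + 1, batch_size).long().cuda()
    output = model(Variable(dummy[:-1]))
    loss = criterion(output.view(-1, output.size(-1)), Variable(dummy[1:]).view(-1))
    loss.backward()       # also allocates the gradient buffers at their maximum size
    model.zero_grad()     # throw the dummy gradients away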

Thanks so much Adam! Answers 1 and 2 answer my questions completely. Regarding 3, I realized I asked two questions in one.

The automatic management (and trying to avoid cudaMalloc and cudaFree) makes sense. With my training loop the GPU memory does grow, but eventually stops once the data has been exhausted. That matches what you describe.

What I don’t understand is that when I’m sampling from the RNN, the memory grows without stopping. So the following code will work for, say, length=100, but will erratically quit around length=200 in my case:

import numpy as np
import torch
from torch.autograd import Variable

def evaluate(model, seeds, length=100):
    seeds = torch.from_numpy(seeds).cuda()
    hidden = model.create_hidden(batch_size=1).cuda()
    hidden = Variable(hidden)

    # prime model with all but last element
    for seed in seeds[:-1]:
        x = Variable(seed.view(1, 1, -1))
        _, hidden = model(x, hidden)

    # run final element and continue
    x = Variable(seeds[-1])
    predicted = []
    for i in range(length):
        x, hidden = model(x.view(1, 1, -1), hidden)
        predicted.append(x.data.cpu().numpy().reshape(-1))

    return np.asarray(predicted)

The GPU memory is released when the function returns, which makes me think that I’m doing something wrong with the way I use x or hidden. Is this my mistake, or does it sound more like a bug? If it sounds like a bug I can try to produce a minimal complete chunk of code that reproduces it.

I think in that last for loop you’re building a new full Variable graph through the model at every iteration, and none of them ever get freed, since each result is used for the next computation. volatile=True in your Variable constructors should help.

Note that you only need to make the input or the hidden state volatile (volatility propagates through the graph with very high precedence).
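Applied to your sampling loop, that would be roughly the following (untested; in later PyTorch versions the same effect is achieved by wrapping the loop in torch.no_grad()):

# run final element and continue, without retaining any graph
x = Variable(seeds[-1], volatile=True)   # volatile propagates through the model,
predicted = []                           # so no history is kept between iterations
for i in range(length):
    x, hidden = model(x.view(1, 1, -1), hidden)
    predicted.append(x.data.cpu().numpy().reshape(-1))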

Hi,

I don’t know if this is the right place, or the right moment, to post this; if not, I apologise and kindly ask you to redirect me to another post, thank you.

I’m also facing a memory-consumption problem when training a model on a single GPU.

The dataset I use is quite small compared with most datasets used nowadays: roughly 13000 sentences for training, 1250 for development and 3000 for testing.
I think the task itself is not important; I just have sentences that I have to tag with labels, one label per token (basically like POS tagging, for those who are familiar with it).
I’m padding sentences with ‘___’ (3 underscores) so that they all have the same length of 96 tokens, and I split the training data into batches of 25 sentences each.
Since I also want to use character features, I pad tokens as well so that all tokens have the same length of 17.
Finally, since I’m testing a variant of a Jordan RNN where the predicted labels are also re-embedded and given back as input to the network, I also use batches of labels as input to the network.

So, to summarise what the input to the network is:

  1. the batch of tokens 96 x 25 indices (LongTensor)
  2. the batch of labels 96 x 25 indices (LongTensor)
  3. the batch of character sequences 96 x 25 x 17 indices (LongTensor)

Now the model:
I use character embeddings of size 50, an LSTM over characters with an output size of 100, token and label embeddings of size 200.
The hidden layer is another bidirectional LSTM, which thus takes as input the token and label embeddings, plus the state output by the character-level LSTM. The output size of the second LSTM is 200 as well.
Finally I use a linear layer to map the hidden layer into the label space and I compute a log-softmax on that.

Since I re-inject predicted labels at each position in a sequence as input to the network, I perform the bi-LSTM forward and backward passes by hand (I checked, the computations are correct; it is just slower than letting the LSTM perform the forward and backward passes over the whole sequence in one shot).

So, why all these details?
Because I made a few simple computations, and I found out that, assuming (generously) 10 bytes per double, the model should take roughly 9.5 MB of memory, while the input to the network (tokens + characters + labels) should take roughly 54 MB.

Instead, when I run the training on the GPU, the process takes 6.5GB of memory.
I know there are a couple more buffers that should be taken into account for gradients, states, etc., but the difference here is more than two orders of magnitude.

If I run the process on CPU, the memory consumption is 1GB in total, so this accounts for the model, the whole dataset and anything else.

I don’t have a lot of experience with PyTorch; can anybody tell me whether this memory consumption on the GPU is normal given the model and the data I use? If needed, I can give more details…

Any help would be more than appreciated
Thank you in advance

Hi Marco,

It’ll be easier to say if you can post some code. As you can see, it’s pretty easy to accidentally leak memory by building an autograd graph that is larger than what you actually need; e.g., if you are not explicitly detaching your hidden state when you move on to the next sentence.
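By detaching I mean something like the usual pattern below (just a sketch): if a recurrent state is carried over from one sentence or batch to the next, cut it loose from the old graph first, otherwise the graphs of all previous batches stay reachable and can never be freed.

from torch.autograd import Variable

def repackage_hidden(h):
    # Return the same hidden state values, detached from the old graph,
    # so everything built for the previous sentence can be freed.
    if isinstance(h, tuple):              # an LSTM state is a (h, c) pair
        return tuple(repackage_hidden(v) for v in h)
    return Variable(h.data)

# before processing the next sentence/batch:
#     self.hidden = repackage_hidden(self.hidden)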

Hi,

Thank you for your answer.
I will follow your suggestion; tomorrow or in the next few days I will post pieces of the code here to make things clear.
Hope you will still be available to take a look.

What do you mean by

if you are not explicitly detaching your hidden state when you move on to the next sentence

?

Thank you again

Hi,

Thank you for your answer.
Following your suggestion, here are parts of the code in order to make things clearer.

So, once I have read the data, I create batches with my own function create_batches:

batch_training_data = create_batches(nn_params, training_data, word_to_ix, tag_to_ix, char_to_ix)
batch_dev_data = create_batches(nn_params, dev_data, word_to_ix, tag_to_ix, char_to_ix)

batch_training_data and batch_dev_data are lists.
word_to_ix, tag_to_ix and char_to_ix are maps from string to indices, just like those in the tutorials in the pytorch website.
Each element of the list is a tuple containing 3 LongTensors:

in_batch = torch.LongTensor(nn_params.max_sentence_length, nn_params.batch_size)
out_batch = torch.LongTensor(nn_params.max_sentence_length, nn_params.batch_size)
char_batch = torch.LongTensor(nn_params.max_sentence_length, nn_params.max_token_length, nn_params.batch_size)

in_batch contains the tokens, out_batch the labels and char_batch the characters of each token.
With the data I’m using, nn_params.max_sentence_length is 96, nn_params.batch_size is 25 and nn_params.max_token_length is 17; these are the dimensions I mentioned in my first post.

Now, the model is made of:

self.word_embeddings  = nn.Embedding(self.vocab_size, nn_params.embedding_dim)
self.char_embeddings  = nn.Embedding(self.char_vocab_size, nn_params.char_embed_dim)
self.label_embeddings = nn.Embedding(self.tagset_size, nn_params.label_embed_dim)
self.embed_dropout    = nn.Dropout(p=0.5)

# character-level LSTM
self.charRNN = nn.LSTM(nn_params.char_embed_dim, nn_params.char_hidden_dim, bidirectional=nn_params.bilstm)
rnn_input_size = nn_params.embedding_dim + nn_params.label_embed_dim + self.num_directions * nn_params.char_hidden_dim
# global model hidden layer
self.RNN = nn.LSTM(rnn_input_size, nn_params.hidden_dim, bidirectional=nn_params.bilstm)
self.hidden_dropout = nn.Dropout(p=0.5)

# mapping from hidden to tag space
self.hidden2tag = nn.Linear(self.num_directions * nn_params.hidden_dim, self.tagset_size)
# auxiliary mapping for the forward and backward steps
self.aux_hidden2tag = nn.Linear(nn_params.hidden_dim, self.tagset_size)

# Main hidden layer state (for self.RNN)
self.hidden = (autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.hidden_dim).type(dtype),
                                 requires_grad=False, volatile=(self.TEST == 1)),
               autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.hidden_dim).type(dtype),
                                 requires_grad=False, volatile=(self.TEST == 1)))

# Character-level hidden layer state (for self.charRNN)
self.char_hidden = (autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.char_hidden_dim).type(dtype),
                                      requires_grad=False, volatile=(self.TEST == 1)),
                    autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.char_hidden_dim).type(dtype),
                                      requires_grad=False, volatile=(self.TEST == 1)))

Basically, since I have to re-inject previously predicted labels as input to the model, I run the forward and backward steps of the LSTM “by hand”. I know this is not the best choice, as it will be slower, and I also do it by running a bidirectional LSTM twice, but I want to try this to compare with my previous models coded in Octave, which were running faster and giving better results (so far…).
In the following:

  1. dtype is torch.FloatTensor or torch.cuda.FloatTensor, depending on whether CUDA is available or not.
  2. num_directions depends on a flag telling if LSTMs are bidirectional or not
  3. There are a couple of intermediate variables needed to store character-level representations (char_rep), bidirectional hidden states (hidden_state), forward pass output (fw_scores), backward pass output (bw_scores), and bidirectional output (output)

The forward function of the model looks like:

vflag = (self.TEST == 1)
(sequence_length, token_length, batch_size) = char_sequence.size()

# 1. Character-level LSTM computing character-level representations, saved in char_rep
char_rep = autograd.Variable(torch.zeros(sequence_length, batch_size, self.num_directions * self.char_hidden_dim).type(dtype), volatile=vflag)
for i in range(sequence_length):
    self.char_hidden = (autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.char_hidden_dim).type(dtype), requires_grad=False, volatile=vflag),
                        autograd.Variable(torch.zeros(self.num_directions, self.batch_size, self.char_hidden_dim).type(dtype), requires_grad=False, volatile=vflag))

    char_embeds = self.char_embeddings(char_sequence[i, :, :])
    lstm_out, self.char_hidden = self.charRNN(char_embeds, self.char_hidden)

    char_rep[i, :, 0:self.char_hidden_dim] = self.char_hidden[0][0, :, :]
    char_rep[i, :, self.char_hidden_dim:2*self.char_hidden_dim] = self.char_hidden[0][1, :, :]

word_embeds = self.embed_dropout(self.word_embeddings(sentence))
old_hidden = self.hidden    # save hidden state to reset it after the backward pass

# 2. Backward pass
bw_scores = autograd.Variable(torch.zeros(sentence_length, batch_size, self.tagset_size).type(dtype), requires_grad=False, volatile=vflag)
hidden_state = autograd.Variable(torch.zeros(sentence_length, batch_size, self.num_directions * self.hidden_dim).type(dtype), requires_grad=False, volatile=vflag)
prev_labels = labels[-1, :]
for i in range(sentence_length - 1, -1, -1):
    label_embeds = self.embed_dropout(self.label_embeddings(prev_labels))
    total_input = torch.cat([word_embeds[i, :, :].view(1, batch_size, -1), char_rep[i, :, :].view(1, batch_size, -1), label_embeds.view(1, batch_size, -1)], 2)
    lstm_out, self.hidden = self.RNN(total_input, self.hidden)
    hidden_state[i, :, self.hidden_dim:2*self.hidden_dim] = lstm_out[:, :, self.hidden_dim:2*self.hidden_dim]
    bw_output = F.log_softmax(self.aux_hidden2tag(self.hidden_dropout(hidden_state[i, :, self.hidden_dim:2*self.hidden_dim])))
    (max_scores, max_indeces) = torch.max(bw_output, 1)
    bw_scores[i, :, :] = bw_output
    prev_labels = max_indeces

# 3. Forward pass
fw_scores = autograd.Variable(torch.zeros(sentence_length, batch_size, self.tagset_size).type(dtype), requires_grad=False, volatile=vflag)
# Re-initialize the hidden state to its original value, so results compare equal to the bi-LSTM run in one shot by the PyTorch LSTM module
self.hidden = old_hidden
prev_labels = labels[0, :]
for i in range(len(sentence)):
    label_embeds = self.embed_dropout(self.label_embeddings(prev_labels))
    total_input = torch.cat([word_embeds[i, :, :].view(1, batch_size, -1), char_rep[i, :, :].view(1, batch_size, -1), label_embeds.view(1, batch_size, -1)], 2)
    lstm_out, self.hidden = self.RNN(total_input, self.hidden)
    hidden_state[i, :, 0:self.hidden_dim] = lstm_out[:, :, 0:self.hidden_dim]
    fw_output = F.log_softmax(self.aux_hidden2tag(self.hidden_dropout(hidden_state[i, :, 0:self.hidden_dim])))
    (max_scores, max_indeces) = torch.max(fw_output, 1)
    fw_scores[i, :, :] = fw_output
    prev_labels = max_indeces

# 4. Compute bidirectional label predictions
output = autograd.Variable(torch.zeros(sentence_length, batch_size, self.tagset_size).type(dtype), volatile=vflag)
for i in range(len(sentence)):
    output[i, :, :] = F.log_softmax(self.hidden2tag(self.hidden_dropout(hidden_state[i, :, :].view(batch_size, -1))))

return output

I use 2 optimisers: one for the whole model, and one for the auxiliary module used to compute the forward and backward predictions only:

optimizer = optim.Adadelta( model.parameters() )
aux_optimizer = optim.Adadelta( model.aux_hidden2tag.parameters() )

I know this is also not the best choice, as the auxiliary optimiser computes many gradients that will not be used, however…
The main training loop looks like this:

for epoch in range(args.epochs):        # args.epochs is the number of epochs passed in from the command line
    train_loss = 0
    model.TEST = 0                      # flag used to know if we are in training or testing mode, so as to make variables volatile

    for i in torch.randperm(len(batch_training_data)):     # looping over shuffled training batches
        model.init_hidden()                                 # re-init the hidden state for each batch

        nn_input = prepare_batch(batch_training_data[i][0], model.CUDA, model.TEST)
        nn_output = prepare_batch(batch_training_data[i][1], model.CUDA, model.TEST)
        input = [nn_input]
        if nn_params.char_features:     # flag telling if we use character features
            input.append(prepare_batch(batch_training_data[i][2], model.CUDA, model.TEST))
        if nn_params.label_features:    # flag telling if we use label features
            input.append(nn_output)

        output = model(input)           # output contains bidirectional, forward and backward scores
        tag_scores = output[0]          # bidirectional scores

        if nn_params.label_features:
            fw_scores = output[1]       # forward scores, created by the "auxiliary module" only when using label features
            bw_scores = output[2]       # backward scores, created by the "auxiliary module" only when using label features

            fw_loss = loss_function(fw_scores.view(sentence_length * batch_size, -1), nn_output.view(sentence_length * batch_size))
            aux_optimizer.zero_grad()                   # optimiser only for the "auxiliary module"
            fw_loss.backward(retain_variables=True)     # parts of the graph are shared with the bidirectional model, so I keep the graph
            aux_optimizer.step()

            bw_loss = loss_function(bw_scores.view(sentence_length * batch_size, -1), nn_output.view(sentence_length * batch_size))
            aux_optimizer.zero_grad()
            bw_loss.backward(retain_variables=True)
            aux_optimizer.step()

        loss = loss_function(tag_scores.view(sentence_length * batch_size, -1), nn_output.view(sentence_length * batch_size))
        train_loss = train_loss + loss.data[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
I know this is a lot of code and a lot of details, so if you, or anyone else, could take the time to take a look, that would be very kind. Please don’t hesitate to ask for clarifications if anything is unclear.

Thank you so much in advance