Memory-efficient BPTT: request for review of a PyTorch implementation

Hi, and thanks for the amazing work on this platform; it’s really a joy to use.

I’ve implemented a version of memory-efficient BPTT similar to what Gruslys et al. (2016), “Memory-Efficient Backpropagation Through Time”, describe as selective hidden state memorization (BPTT-HSM). The code is here; it is based on the “word_language_model” example from pytorch/examples.

On my Titan X (Pascal) GPU with the PTB dataset, a BPTT length of 512 takes 11.8 seconds per epoch and uses 2585 MiB. With the memory-efficient approach and a step size of 256, it takes 13.1 seconds per epoch but uses only 1543 MiB (PyTorch 0.2.0.post3). After subtracting a baseline of about 460 MiB, that is pretty close to half the memory.
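The “close to half” claim can be checked from the numbers above. This treats the roughly 460 MiB base as fixed overhead that does not depend on the BPTT length, which is an approximation:

```python
# Figures reported above, in MiB and seconds per epoch.
full_mib, hsm_mib, base_mib = 2585, 1543, 460
full_sec, hsm_sec = 11.8, 13.1

# Memory attributable to the BPTT graph, assuming the base is fixed overhead.
ratio = (hsm_mib - base_mib) / (full_mib - base_mib)   # 1083 / 2125
overhead = hsm_sec / full_sec - 1.0                   # relative extra time per epoch

print(f"memory ratio:  {ratio:.3f}")     # memory ratio:  0.510
print(f"time overhead: {overhead:.1%}")  # time overhead: 11.0%
```

So under that assumption the graph memory is roughly halved for about 11% more time per epoch.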

Briefly, the approach splits the desired backpropagation sequence length into smaller pieces (steps) and runs backpropagation on each piece in reverse order, attaching the gradients on each piece’s input hidden state to the output hidden state of the piece before it. The memory saving comes from the fact that the longest chain of Variable history that must be kept alive is one step long rather than the full sequence length. For example, a 512-step sequence for an RNN can be processed as two 256-step pieces for roughly half the memory, at the cost of an extra 256-step forward pass.
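The same scheme can be sketched without any framework. This is a toy under stated assumptions: a scalar tanh RNN with a squared-error loss stands in for the LSTM and cross-entropy loss, and all function names are hypothetical. Pass 1 keeps only the states at segment boundaries (the role of hsm); pass 2 re-runs each segment forward, last segment first, and hands the gradient on its input state to the segment before it. Both versions should agree up to floating-point rounding.

```python
import math

def rnn_step(w, u, h_prev, x):
    # Toy scalar cell (hypothetical stand-in for the LSTM): h_t = tanh(w*h_{t-1} + u*x_t)
    return math.tanh(w * h_prev + u * x)

def segment_backward(w, u, h_in, xs, ys, grad_h_out):
    """Re-run one segment forward (keeping its activations), then backprop its
    loss sum(0.5*(h_t - y_t)**2) plus grad_h_out, the gradient arriving at the
    segment's last hidden state from later segments.
    Returns (loss, dL/dw, dL/du, dL/dh_in)."""
    hs = [h_in]
    for x in xs:
        hs.append(rnn_step(w, u, hs[-1], x))
    loss = sum(0.5 * (h - y) ** 2 for h, y in zip(hs[1:], ys))
    dw = du = 0.0
    g = grad_h_out
    for t in reversed(range(len(xs))):
        g += hs[t + 1] - ys[t]            # add the local loss gradient at step t
        da = g * (1.0 - hs[t + 1] ** 2)   # back through tanh
        dw += da * hs[t]
        du += da * xs[t]
        g = da * w                        # gradient w.r.t. the previous state
    return loss, dw, du, g

def full_bptt(w, u, h0, xs, ys):
    # Ordinary BPTT: a single segment spanning the whole sequence.
    loss, dw, du, _ = segment_backward(w, u, h0, xs, ys, 0.0)
    return loss, dw, du

def checkpointed_bptt(w, u, h0, xs, ys, step):
    starts = list(range(0, len(xs), step))
    # Pass 1 (forward only): keep just the state entering each segment.
    checkpoints = [h0]
    h = h0
    for s in starts[:-1]:
        for x in xs[s:s + step]:
            h = rnn_step(w, u, h, x)
        checkpoints.append(h)
    # Pass 2: segments in reverse order, carrying dL/dh across boundaries.
    loss = dw = du = 0.0
    g = 0.0  # nothing flows into the final segment from the future
    for k in reversed(range(len(starts))):
        s = starts[k]
        l, sdw, sdu, g = segment_backward(w, u, checkpoints[k],
                                          xs[s:s + step], ys[s:s + step], g)
        loss, dw, du = loss + l, dw + sdw, du + sdu
    return loss, dw, du

xs = [math.sin(0.3 * t) for t in range(20)]
ys = [0.5 * math.cos(0.2 * t) for t in range(20)]
print(full_bptt(0.7, 0.4, 0.1, xs, ys))
print(checkpointed_bptt(0.7, 0.4, 0.1, xs, ys, step=8))  # same values up to rounding
```

Only one segment’s activations are alive during pass 2, which is where the saving comes from; the price is the extra forward pass over every segment except the last.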

It seems to perform as expected, trading reduced memory for extra forward steps. However, I’d appreciate a glance at some of the trickier (for me) spots, to make sure I’m handling Variables and gradients correctly and following PyTorch best practice.

First step:

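        # forward-only pass: no graph is kept, only the checkpointed states
        hidden_v = repackage_hidden(hidden, volatile=True)
        data_v, _ = get_batch(train_data, i, evaluation=True)
        # hsm[k] holds the state entering segment k+1; hsm[-1] is the initial state
        hsm = { -1 : repackage_hidden(hidden) }
        intervals = list(enumerate(range(0, data.size(0), args.bptt_step)))
        for f_i,f_v in intervals[:-1]:
            output,hidden_v = model(data_v[f_v:f_v+args.bptt_step], hidden_v)
            # fresh non-volatile leaf that will collect the gradient flowing back into it
            hsm[f_i] = repackage_hidden(hidden_v, volatile=False, requires_grad=True)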

For memory efficiency, I use volatile=True Variables for this forward pass so that it captures just the hidden states needed (which occur at args.bptt_step intervals throughout the longer args.bptt sequence). Questions: can maximum memory efficiency be achieved with only the hidden state (hidden_v) volatile here, rather than also the input data_v? Or should both be volatile to be safe?

I assume the state tensors (.data) can be safely stored just by repackaging each Variable, i.e. creating a new Variable from the old contents.

Setting requires_grad=True ensures that each stored state receives a gradient during backprop, which is then used to backpropagate through the output of the segment that precedes it in the sequence.
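A rough count shows what storing states only at args.bptt_step intervals buys. This sketch assumes memory is dominated by one stored hidden activation per timestep, which ignores parameters, embeddings, and other fixed costs; the function names are hypothetical:

```python
def peak_stored_states(seq_len, step):
    """Rough peak number of per-timestep states held at once.

    Crude model (an assumption): one checkpoint per segment plus the full
    activations of the single segment being backpropagated. With
    step >= seq_len this degenerates to ordinary BPTT.
    """
    n_segments = -(-seq_len // step)  # ceiling division
    return n_segments + min(step, seq_len)

def extra_forward_steps(seq_len, step):
    # Every segment except the last is run forward twice.
    n_segments = -(-seq_len // step)
    return (n_segments - 1) * step

for step in (512, 256, 128, 32):
    print(step, peak_stored_states(512, step), extra_forward_steps(512, step))
```

For a 512-step sequence this gives 258 stored states with a step of 256 versus 513 for plain BPTT, in line with the roughly-half measurement. Shrinking the step keeps cutting memory until the checkpoint count catches up (under this model the sum is smallest near step ≈ sqrt(seq_len)), while the extra forward work keeps growing.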

Second step, main algorithm:

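        loss = 0
        for b_i, b_v in reversed(intervals):
            # re-run this segment with a graph, starting from its checkpointed state
            output,h = model(data[b_v:b_v+args.bptt_step], hsm[b_i-1])
            # assumes targets is the flattened (seq_len * batch_size) Variable from get_batch
            iloss = criterion(output.view(-1, ntokens),
                              targets[b_v*data.size(1):(b_v+args.bptt_step)*data.size(1)])
            if b_v+args.bptt_step >= data.size(0):
                # 1. state for next training sequence
                hidden = h
                #    last segment in time: no gradient arrives from the future
                iloss.backward()
            else:
                # 2. pass gradients saved on input states to output state
                # variables along with output loss to multi-variable backprop
                variables, grad_variables = [iloss], [None]
                for l in h:
                    g = save_grad.popleft()
                    variables.append(l)
                    grad_variables.append(g)
                torch.autograd.backward(variables, grad_variables)
            if b_i > 0:
                # 3. store gradients initialized on input states
                save_grad = collections.deque()
                for l in hsm[b_i-1]:
                    save_grad.append(l.grad)
            loss += iloss.data[0]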
  1. Here, the first fragment processed (the last one in time) yields the hidden state for the next training sequence, and its incremental loss and gradient are computed without any gradient from the future.

  2. On every fragment except the first one processed, the gradients saved from the previously processed fragment become the gradients associated with this fragment’s output state variables, and are passed to torch.autograd.backward() together with the loss.

  3. On every fragment except the last one processed (the first in time), the gradients on the input states are stored in a deque for use with the next fragment. Question: is it safe to pass these gradient Variables directly to torch.autograd.backward later, without a clone?
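Step 2 relies on the fact that calling torch.autograd.backward on the loss and the output states, with the saved gradients as their grad_variables (and None, i.e. 1, for the scalar loss), propagates the same gradient as backpropagating the scalar loss + g·h_out with g held constant. A one-step scalar check of that identity, where the tanh cell, the squared-error loss, and all values are hypothetical stand-ins for the model:

```python
import math

# One hypothetical segment: h_out = tanh(w*h_in + u*x), local loss 0.5*(h_out - y)**2.
w, u, x, y = 0.8, -0.3, 0.5, 0.2
g = 0.37  # gradient on h_out saved from the later segment (a constant here)

def forward(h_in):
    h_out = math.tanh(w * h_in + u * x)
    return h_out, 0.5 * (h_out - y) ** 2

def injected_grad(h_in):
    # What backward([loss, h_out], [1, g]) sends to h_in:
    # (dloss/dh_out + g) * dh_out/dh_in
    h_out, _ = forward(h_in)
    return (h_out - y + g) * (1.0 - h_out ** 2) * w

def surrogate(h_in):
    # Scalar whose ordinary gradient should match: loss + g * h_out
    h_out, loss = forward(h_in)
    return loss + g * h_out

h_in, eps = 0.1, 1e-6
numeric = (surrogate(h_in + eps) - surrogate(h_in - eps)) / (2 * eps)
print(injected_grad(h_in), numeric)  # should agree closely
```

The same argument applies to every parameter inside the segment, which is why the per-segment gradients can simply accumulate.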

(This implementation also turns off size_average in the CrossEntropyLoss criterion and averages the gradients only after all backprop is done.)
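Why size_average matters here: gradients are linear in the loss, so summing per-segment losses and dividing once by the total token count reproduces the full-sequence mean, whereas adding per-segment means scales the result by the number of segments (for equal-length segments). A plain-number illustration with made-up per-token loss values:

```python
# Hypothetical per-token losses for one 12-token training sequence,
# split into segments of 4 tokens as the segmented BPTT does.
token_losses = [float(i % 5) for i in range(12)]
step = 4
segments = [token_losses[i:i + step] for i in range(0, len(token_losses), step)]

full_mean = sum(token_losses) / len(token_losses)

# size_average off: sum inside each segment, divide once at the end.
summed_then_averaged = sum(sum(s) for s in segments) / len(token_losses)

# size_average on for each segment, with the results simply added.
sum_of_segment_means = sum(sum(s) / len(s) for s in segments)

print(full_mean, summed_then_averaged, sum_of_segment_means)  # 1.75 1.75 5.25
```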

Thanks for any comments.

volatile is a viral flag: if any input is volatile, the output is volatile too. So it is enough for one input, e.g. hidden_v, to be volatile.

3: yes.

Looks good overall. Instead of repackage_*, you can use hidden.detach_(), but I am not sure if that frees the memory (it might be an outstanding bug).

This concept of trading off memory for compute in autodiff engines is called checkpointing.

In my opinion (I didn’t see the full code), you could clean this up and turn it into a more generic checkpointing function that takes a model, the data, and the maximum number of timesteps per checkpointed segment.
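One way such a generic function could look, sketched in plain Python so the bookkeeping is visible. The model is supplied as two hypothetical callables: a graph-free forward over a segment, and a backward that re-runs the segment and returns its parameter gradients plus the gradient on its input state. In PyTorch these would presumably wrap the volatile forward pass and the torch.autograd.backward call from the post; that mapping is an assumption, not tested code. The usage plugs in a toy linear RNN whose loss is the sum of its states:

```python
from typing import Callable, List, Optional, Sequence

def checkpointed_bptt(forward: Callable, backward: Callable,
                      inputs: Sequence, h0, max_steps: int) -> Optional[List[float]]:
    """Generic segmented BPTT (hypothetical helper).

    forward(h, xs) -> h_out runs a segment without keeping a graph.
    backward(h_in, xs, grad_h_out) -> (param_grads, grad_h_in) re-runs a
    segment, backprops its loss plus grad_h_out (None = nothing from the
    future), and returns per-parameter gradients and dL/dh_in.
    """
    starts = list(range(0, len(inputs), max_steps))
    # Forward pass: remember only the state entering each segment.
    checkpoints = [h0]
    for s in starts[:-1]:
        checkpoints.append(forward(checkpoints[-1], inputs[s:s + max_steps]))
    # Backward pass: last segment first, handing dL/dh to the earlier one.
    total, grad_h = None, None
    for k in reversed(range(len(starts))):
        s = starts[k]
        grads, grad_h = backward(checkpoints[k], inputs[s:s + max_steps], grad_h)
        total = grads if total is None else [a + b for a, b in zip(total, grads)]
    return total

def make_linear_rnn(a):
    """Toy model (hypothetical): h_t = a*h_{t-1} + x_t, loss = sum of all h_t."""
    def forward(h, xs):
        for x in xs:
            h = a * h + x
        return h

    def backward(h_in, xs, grad_h_out):
        hs = [h_in]
        for x in xs:
            hs.append(a * hs[-1] + x)
        g = 0.0 if grad_h_out is None else grad_h_out
        da = 0.0
        for t in reversed(range(len(xs))):
            g += 1.0          # hs[t+1]'s own term in the loss
            da += g * hs[t]   # hs[t+1] = a*hs[t] + xs[t], so d hs[t+1]/da = hs[t]
            g *= a            # gradient w.r.t. the previous state
        return [da], g

    return forward, backward

fwd, bwd = make_linear_rnn(0.9)
xs = [0.5, -1.0, 0.25, 2.0, -0.75, 1.5, 0.0]
print(checkpointed_bptt(fwd, bwd, xs, 0.3, max_steps=3))
print(checkpointed_bptt(fwd, bwd, xs, 0.3, max_steps=len(xs)))  # ordinary BPTT
```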

@smth Thanks, much obliged! I’ll check on detach and see where we’re at. A generic checkpointing function also makes great sense; I’ll do that.