Why does torch.cuda.memory_allocated report that GPU memory usage was decreasing during the forward pass?

Hi, I’m trying to record CUDA GPU memory usage using the torch.cuda.memory_allocated API. My goal is to plot a diagram of GPU memory usage (in MB) during the forward pass.

This is the nn.Module subclass I’m using; it uses the register_forward_hook method of nn.Module to log the memory usage after each targeted submodule’s forward method is called:

import torch
from torch.utils.checkpoint import checkpoint


class Segment(torch.nn.Module):
    def __init__(self,
                 specs_truncated,
                 in_channels,
                 use_checkpoint=False,
                 use_batch_norm=True,
                 ):
        super().__init__()
        self.subspecs = specs_truncated
        self.in_channels = in_channels
        self.use_checkpoint = use_checkpoint
        self.use_batch_norm = use_batch_norm
        self.seq_of_nodes = self._make_sequential_by_specs_truncated()
        self.handles_forward_hooks = self._register_forward_hooks(
            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU, torch.nn.MaxPool2d],
            lambda d: lambda m, i:
                print("Memory allocated before execute {}-{}: {} MB".format(
                    d['name'], d['class'],
                    torch.cuda.memory_allocated()/1e6)),
            lambda d: lambda m, i, o:
                print("Memory allocated after execute {}-{}: {} MB".format(
                    d['name'], d['class'],
                    torch.cuda.memory_allocated()/1e6)),
        )

    def forward(self, x):
        if self.use_checkpoint:
            return checkpoint(self.seq_of_nodes, x, use_reentrant=False)
        return self.seq_of_nodes(x)

    def _make_sequential_by_specs_truncated(self):
        layers = []
        in_channels = self.in_channels
        for node_spec in self.subspecs:
            if node_spec == 'M':
                layers += [
                        torch.nn.MaxPool2d(kernel_size=2),
                        ]
            else:
                layers += [
                        torch.nn.Conv2d(in_channels, node_spec, kernel_size=3, padding='same'),
                        ]
                # keep BatchNorm2d only when use_batch_norm is True
                # (a bool slices as 0/1, so [not use_batch_norm:] drops the first entry)
                layers += [
                        torch.nn.BatchNorm2d(node_spec),
                        torch.nn.ReLU(),
                        ][not self.use_batch_norm:]
                in_channels = node_spec

        return torch.nn.Sequential(*layers)

    def _register_forward_hooks(self, list_targets, pre_hook, post_hook):
        handles = {}
        for name, module in self.seq_of_nodes.named_modules():
            if isinstance(module, tuple(list_targets)):
                # handles[name + '.pre'] = module.register_forward_pre_hook(
                #    pre_hook({ 'name': name, 'class': module.__class__.__name__ }))
                handles[name + '.post'] = module.register_forward_hook(
                    post_hook({ 'name': name, 'class': module.__class__.__name__ }))
        return handles

    def deregister_forward_hooks(self):
        for handle in self.handles_forward_hooks.values():
            handle.remove()

But I got a very unintuitive result, as described in the title:

Memory allocated after execute 0-Conv2d: 609.58208 MB
Memory allocated after execute 1-BatchNorm2d: 643.136512 MB
Memory allocated after execute 2-ReLU: 643.136512 MB
Memory allocated after execute 3-Conv2d: 643.136512 MB
Memory allocated after execute 4-BatchNorm2d: 643.136512 MB
Memory allocated after execute 5-ReLU: 643.136512 MB
Memory allocated after execute 6-MaxPool2d: 617.970688 MB
Memory allocated after execute 0-Conv2d: 601.193472 MB
Memory allocated after execute 1-BatchNorm2d: 617.970688 MB
Memory allocated after execute 2-ReLU: 617.970688 MB
Memory allocated after execute 3-Conv2d: 617.970688 MB
Memory allocated after execute 4-BatchNorm2d: 617.970688 MB
Memory allocated after execute 5-ReLU: 617.970688 MB
Memory allocated after execute 6-MaxPool2d: 605.912064 MB
Memory allocated after execute 0-Conv2d: 597.523456 MB
Memory allocated after execute 1-BatchNorm2d: 606.96064 MB
Memory allocated after execute 2-ReLU: 606.96064 MB
Memory allocated after execute 3-Conv2d: 606.96064 MB
Memory allocated after execute 4-BatchNorm2d: 606.96064 MB
Memory allocated after execute 5-ReLU: 606.96064 MB
Memory allocated after execute 6-Conv2d: 606.96064 MB
Memory allocated after execute 7-BatchNorm2d: 606.96064 MB
Memory allocated after execute 8-ReLU: 606.96064 MB
Memory allocated after execute 9-Conv2d: 606.96064 MB
Memory allocated after execute 10-BatchNorm2d: 606.96064 MB
Memory allocated after execute 11-ReLU: 606.96064 MB
Memory allocated after execute 12-MaxPool2d: 600.669184 MB
Memory allocated after execute 13-Conv2d: 595.426304 MB
Memory allocated after execute 14-BatchNorm2d: 597.523456 MB
Memory allocated after execute 15-ReLU: 597.523456 MB
Memory allocated after execute 16-Conv2d: 597.523456 MB
Memory allocated after execute 17-BatchNorm2d: 597.523456 MB
Memory allocated after execute 18-ReLU: 597.523456 MB
Memory allocated after execute 19-Conv2d: 597.523456 MB
Memory allocated after execute 20-BatchNorm2d: 597.523456 MB
Memory allocated after execute 21-ReLU: 597.523456 MB
Memory allocated after execute 22-Conv2d: 597.523456 MB
Memory allocated after execute 23-BatchNorm2d: 597.523456 MB
Memory allocated after execute 24-ReLU: 597.523456 MB
Memory allocated after execute 25-MaxPool2d: 594.377728 MB
Memory allocated after execute 26-Conv2d: 591.232 MB
Memory allocated after execute 27-BatchNorm2d: 591.232 MB
Memory allocated after execute 28-ReLU: 591.232 MB
Memory allocated after execute 29-Conv2d: 591.232 MB
Memory allocated after execute 30-BatchNorm2d: 591.232 MB
Memory allocated after execute 31-ReLU: 591.232 MB
Memory allocated after execute 32-Conv2d: 591.232 MB
Memory allocated after execute 33-BatchNorm2d: 591.232 MB
Memory allocated after execute 34-ReLU: 591.232 MB
Memory allocated after execute 35-Conv2d: 591.232 MB
Memory allocated after execute 36-BatchNorm2d: 591.232 MB
Memory allocated after execute 37-ReLU: 591.232 MB

As you can see, the memory usage reported by memory_allocated shows that GPU usage was decreasing!

Could anyone help me identify whether I’m misusing any API in my code? Thanks!

Autograd will store intermediate tensors needed for the gradient computation and will free other unused temporary tensors. Depending on the model architecture and the modules used, larger tensors might no longer be needed, and freeing them will thus reduce the reported memory usage.
To check your model, you could verify the expected memory usage of each submodule/layer and compare it to your output.
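As an illustration, a minimal sketch of such a check (the layer list and input shape are assumptions for this example): it compares the size of each layer’s output tensor, which autograd may keep for the backward pass, against the values printed by torch.cuda.memory_allocated().

import torch

# Hypothetical example: measure how much memory each layer's output occupies,
# to compare against the numbers printed by torch.cuda.memory_allocated().
layers = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding='same'),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
).cuda()

x = torch.randn(8, 3, 224, 224, device='cuda')  # assumed input shape
for i, layer in enumerate(layers):
    x = layer(x)
    size_mb = x.numel() * x.element_size() / 1e6
    print(f"{i}-{layer.__class__.__name__}: output tensor is {size_mb:.2f} MB")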


@ptrblck Sorry, but I still don’t know how to prevent this from happening. Did you mean that this is normal behavior of PyTorch’s autograd implementation?

I would also check what happens if you add torch.cuda.synchronize() calls before each print, just to rule out the possibility that temporary workspaces for convolutions are being counted in the memory usage. Additionally, if you are using checkpointing, then indeed some temporary activations would be garbage collected and their memory reused, which would free memory as @ptrblck suggested.


I would also check what happens if you add torch.cuda.synchronize() calls before each print […]

@eqy First of all, many thanks for reading my post :slight_smile:!

I did, and I got the same result. I even tried adding torch.cuda.empty_cache() before and after the synchronize() call (not shown here, since it had no effect). The callback I pass to register_forward_hook now calls torch.cuda.synchronize() right before printing.
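A minimal sketch of such a post-hook, following the hook-factory pattern from the code above (illustrative only, not the exact callback from the original post):

import torch

def make_post_hook(info):
    # info is a dict with 'name' and 'class' keys, as in the hook factory above
    def hook(module, inputs, output):
        torch.cuda.synchronize()  # wait for all queued kernels before reading the counter
        print("Memory allocated after execute {}-{}: {} MB".format(
            info['name'], info['class'], torch.cuda.memory_allocated() / 1e6))
    return hook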

Additionally, if you are using checkpointing, then […]

I’m happy that you mentioned checkpointing. Could you shed some light on how I could check this: say I’ve divided my model into many Segments (each an nn.Module). I want to know whether torch.utils.checkpoint will use some allocator (so far I only know of two kinds: a sync one and an async one) to reserve GPU memory for all nodes (e.g. built-in classes like nn.Conv2d, etc.) of a Segment at once before forwarding them one by one. Otherwise, PyTorch would have to allocate the GPU memory before running each node, which seems unlikely, as GPU memory allocation is expensive.

(Sorry for asking an additional question in this reply; my intention is to keep it compact, and any insight into the question itself will help me a lot.)

You can read the implementation of torch.utils.checkpoint directly here: torch.utils.checkpoint — PyTorch 2.0 documentation
(The sequential version will be in the same source file in case you are using that instead)
You will see that there is mostly nothing special about the checkpointing use case in the forward pass (especially with regard to calling an allocator explicitly or switching allocators); the main difference is that it runs each checkpointed segment under a with torch.no_grad(): context. This context allows the intermediate forward activation(s) to be discarded within each segment rather than being saved for backward, which explains why the memory usage can increase and then decrease within a segment. If you check the memory usage at the beginning and end of each segment, rather than at each layer, I believe you should see nondecreasing memory usage.
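For example, a minimal sketch of checking the memory at segment boundaries (assuming the model is composed of several Segment modules, as in the code above):

import torch

def log_segment_boundaries(model):
    # Register hooks on each top-level Segment (not on individual layers) so the
    # printed values reflect memory at segment boundaries only. Hypothetical
    # helper for illustration.
    handles = []
    for name, module in model.named_children():
        handles.append(module.register_forward_pre_hook(
            lambda m, i, n=name: print(
                f"{n} start: {torch.cuda.memory_allocated() / 1e6:.2f} MB")))
        handles.append(module.register_forward_hook(
            lambda m, i, o, n=name: print(
                f"{n} end:   {torch.cuda.memory_allocated() / 1e6:.2f} MB")))
    return handles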

Simply put, if you have a model with layers A->B->C->D->E->F, and segments A->B->C, D->E->F, all checkpointing will do is discard the inputs and outputs of B and E when they are no longer needed after the forward pass (this is automatically done by putting A->B->C and D->E->F in the no-grad context), and recompute them by rerunning A->B and D->E when they are needed again in the backward pass.
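As an illustration, a minimal sketch of that kind of segmentation using torch.utils.checkpoint.checkpoint_sequential (the layer sizes and two-segment split are arbitrary choices for this example):

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Six layers standing in for A->B->C->D->E->F, split into two checkpointed
# segments; intermediates within each segment are discarded after the forward
# pass and recomputed during backward.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(6)]).cuda()
x = torch.randn(32, 1024, device='cuda', requires_grad=True)

out = checkpoint_sequential(model, 2, x)
out.sum().backward()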

Finally, as you noted, GPU memory allocation via raw cudaMalloc calls is expensive, but the caching allocator continuously recycles allocations by maintaining its own pool(s) of GPU memory (see CUDA semantics — PyTorch 2.0 documentation). In practice, once a training loop is warmed up, we would not expect to see any more raw cudaMalloc calls.


@eqy I just confirmed this, thank you! :slight_smile: So, to conclude: torch.utils.checkpoint itself is not involved in GPU memory allocation, which is the job of the caching allocator. But whether a tensor’s memory can be recycled, which is controlled by with torch.no_grad():, is indeed related to torch.utils.checkpoint, as per its implementation. (Please correct me if I got this wrong.)

TIL these two facts; I appreciate your explanations:

  • This context allows the intermediate forward activation(s) to be discarded within each segment rather than being saved for backward, which explains why the memory usage can increase and then decrease within a segment.
  • In practice, once a training loop is warmed up, we would not expect to see any more raw cudaMalloc calls.

If you’re still interested in this topic, one last question please:

Now, for me, the advantage of the torch.utils.checkpoint API is that it can keep the GPU memory usage (i.e. torch.cuda.memory_allocated) within the fixed amount pre-allocated by the caching allocator. Would it be practical to rewrite the caching allocator (or maybe one of them) so that the amount of GPU memory it pre-allocates won’t be too much larger than the memory actually being used, i.e. the value of torch.cuda.memory_allocated()?

My intention behind this is simple: for me, torch.utils.checkpoint is a tool to achieve “training a large model on a small GPU”. The pre-allocation by the caching allocator seems to negate this effect. (I haven’t investigated how PyTorch implements its caching allocators, so I’m probably wrong about this conclusion.) Could you also share some insight regarding this concern?

If you are concerned that the caching allocator is “wasting” memory by preallocating more than what is actually used, let me clarify that this should not be the case unless you are concerned with fragmentation (which I don’t think is the intent of your question). When there is more GPU memory available than is actually being used, the caching allocator may reserve more memory than is needed, but at the limit (your case of training a large model on a small GPU), the caching allocator can only reserve as much memory as the GPU has, and all of this memory would be available to model training (minus some hopefully small fraction due to fragmentation and GPU kernels). From this perspective, the caching allocator is meant to provide speedy allocations with minimal actual wasted memory usage.
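For reference, a minimal sketch showing the difference between what live tensors actually use and what the allocator has reserved (the model and input here are arbitrary examples):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).cuda()
out = model(torch.randn(8, 3, 224, 224, device='cuda'))

# memory_allocated: memory occupied by live tensors
# memory_reserved:  memory held by the caching allocator (>= allocated)
print(f"allocated: {torch.cuda.memory_allocated() / 1e6:.2f} MB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1e6:.2f} MB")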

With that said, the main limitation of torch.utils.checkpoint when it comes to training large models with the least amount of memory possible is that it can only recompute intermediate activations at most once. In theory, a more sophisticated checkpointing algorithm could recompute intermediate activations many times. Consider the A->B->C->D->E->F example shown previously. A more sophisticated checkpointing algorithm could store just the input to A and the output of F following the forward pass, and e.g.,
recompute B, recompute C, throw out B, recompute D, throw out C, recompute E, throw out D, compute F’s backward, throw out F, recompute B, recompute C, throw out B, recompute D, throw out C, compute E’s backward, throw out E, … and so on. This approach is more formally described and implemented in this paper: [2006.09616] Dynamic Tensor Rematerialization, which you may be interested in.


[…] the caching allocator can only reserve as much memory as the GPU has, and all of this memory would be available to model training […] From this perspective, the caching allocator is meant to provide speedy allocations with minimal actual wasted memory usage.

@eqy Could you elaborate more on this part? I don’t think I’ve fully understood it. When a caching allocator reserves, say, 20 GiB of GPU memory but only 12 GiB is used during training, isn’t that a kind of wasting memory? From my current intuition (since I haven’t fully read the implementation details), the reserved part cannot be used by other programs. In short, I think I might be misunderstanding your words.

I’m fine with the other part of your last comment, and I will read the paper you kindly linked.

Again, many thanks for reading!

Regarding the first point, PyTorch is usually expected to be the only process running on a GPU. If you know the maximum amount of memory PyTorch should use and expect other processes to need memory as well, you can limit the amount of memory the caching allocator can reserve with e.g. torch.cuda.set_per_process_memory_fraction — PyTorch 2.0 documentation
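A minimal usage sketch (the 0.5 fraction and device index are arbitrary values for illustration):

import torch

# Cap this process's caching allocator at 50% of device 0's total memory;
# allocations beyond the cap will raise an out-of-memory error.
torch.cuda.set_per_process_memory_fraction(0.5, device=0)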


[…] the caching allocator can only reserve as much memory as the GPU has, and all of this memory would be available to model training (minus some hopefully small fraction due to fragmentation and GPU kernels).

@eqy So am I correct to say that PyTorch has implemented some algorithm in the caching allocator to predict how much GPU memory will be reserved before training, since you only mentioned that it will “at most” use all the memory of the current single GPU? (Am I also correct to say that this is the reason for doing the warm-up round in the first few batches?)

If you know the maximum amount of memory PyTorch should use and expect to use memory for other process […] you can limit the amount of memory the caching allocator can reserve with e.g., torch.cuda.set_per_process_memory_fraction […]

Yes, I’m working on an algorithm to make such a prediction. So I will assume that this API is only useful when the caching allocator’s default behavior does not pick an optimal fraction (a memory upper bound very close to what will actually be used), right? I think I need to know where these files are located in the source; could you help? (I selected your reply above as the solution because of the thread title, so this thread basically ends here, but I’m interested in discussing further.)

Again, it’s amazing that you replied to all of this with valuable details and references. Thank you very much!

PyTorch by default should not make strong assumptions about how much memory will be used during training, as it is meant to support dynamism during training. The caching allocator’s reservation of additional memory is based on heuristics rather than on trying to “predict” how much memory should be reserved. I think the code for allocation size calculation/rounding might be relevant: pytorch/CUDACachingAllocator.cpp at d5f15d351510d1cb5e2294deb07569d343eea51f · pytorch/pytorch · GitHub

As the caching allocator has no foresight into the maximum memory required and will need to reserve it on-the-fly at least once, warmup iteration(s) will observe raw cudaMalloc calls while later iterations should not.
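One way to observe this is to track torch.cuda.memory_reserved() across iterations; after the warmup iteration(s) the reserved amount should stop growing (the model, data, and step count below are arbitrary examples):

import torch

model = torch.nn.Linear(4096, 4096).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(5):
    x = torch.randn(64, 4096, device='cuda')
    loss = model(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    # reserved memory should plateau after the warmup iteration(s)
    print(f"step {step}: reserved {torch.cuda.memory_reserved() / 1e6:.2f} MB")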

As for setting the maximum memory via the per-process memory fraction, that code should be in the same file: pytorch/CUDACachingAllocator.cpp at main · pytorch/pytorch (github.com)