Memory size of all tensors referenced by the autograd graph?

Is there a way to get statistics on the memory used by activations/intermediate tensors in the autograd graph (after the forward pass has finished computing the loss)?

This is useful for debugging insufficient memory savings from autocast.

You could use forward hooks to check the size of the intermediate forward activations.
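A minimal sketch of the forward-hook approach, under some assumptions: it only records the outputs of leaf modules (the helper `register_activation_size_hooks` is a hypothetical name), it does not deduplicate outputs produced in place, and it cannot see tensors that ops save internally. Under autocast, `element_size()` reflects the lower-precision dtype, so fp16 outputs count 2 bytes per element.

```python
import torch
import torch.nn as nn

def register_activation_size_hooks(model: nn.Module, sizes: dict):
    """Hypothetical helper: record the output size (bytes) of every leaf module."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers so their outputs are not double counted
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):
                sizes[name] = output.numel() * output.element_size()
        handles.append(module.register_forward_hook(hook))
    return handles

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
x = torch.randn(4, 10)

sizes = {}
handles = register_activation_size_hooks(model, sizes)
out = model(x)
for h in handles:
    h.remove()  # detach the hooks once the measurement is done

total = sum(sizes.values())
print(sizes, total)
```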

But it won't account for certain tensors being garbage-collected, for in-place ops, or for intermediate tensors saved for backward inside the C++ ops, right?
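As far as I know, a hook that does see tensors saved inside the C++ ops is `torch.autograd.graph.saved_tensors_hooks`: its pack hook runs for every tensor autograd saves for backward. A minimal sketch, assuming PyTorch 2.x (for `untyped_storage()`), that deduplicates by storage so views and tensors shared between ops are counted once, then excludes parameter storages to estimate activation memory. Caveats: storage sizes can overstate memory for small views into large buffers, and under autocast the low-precision weight copies are not parameter storages, so they would be counted as activations here.

```python
import torch
import torch.nn as nn

saved = {}  # (storage data_ptr, device) -> storage size in bytes

def pack_hook(t: torch.Tensor):
    # Runs for every tensor autograd saves for backward, including tensors
    # saved by built-in (C++) derivative formulas.
    storage = t.untyped_storage()
    saved[(storage.data_ptr(), t.device)] = storage.nbytes()
    return t  # keep the tensor unchanged; we only observe it

def unpack_hook(t):
    return t

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
x = torch.randn(4, 10)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = model(x).sum()

param_ptrs = {p.untyped_storage().data_ptr() for p in model.parameters()}
total_bytes = sum(saved.values())
activation_bytes = sum(n for (ptr, _), n in saved.items() if ptr not in param_ptrs)
print(f"{len(saved)} storages, {total_bytes} bytes saved, {activation_bytes} non-parameter")
```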

Basically, the question is still about memory consumption that isn't easily observable, combined with the lack of clarity about what actually gets stored as fp16 (as a single copy), how and where casts are done, and whether they are done on the fly. Autocast is a bit too "auto" and too magical :)

It would be good to have some kind of tracing mode to see at what point casts are executed (while in eager mode).

To check for casts and track operations you could use something like this:

```python
import torch
from torch.testing._internal.logging_tensor import LoggingTensorMode, capture_logs

model = torch.nn.Linear(10, 10).cuda()
x = torch.randn(1, 10).cuda()

# Log every aten op (including the dtype casts inserted by autocast) run in the forward pass
with capture_logs(is_mode=True) as logs, LoggingTensorMode():
    with torch.autocast(device_type='cuda'):
        out = model(x)

for l in logs:
    print(l)
```

which should reduce the “magic”.
I’m not aware of another way to check for intermediates created in the backend.

Curious: how would it show "cached" casts in autocast? Which casts does autocast cache? Casts of model parameters from fp32 to fp16, right? At what points is the cache invalidated or flushed?
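As far as I know, autocast only caches casts of leaf tensors that require grad (in practice, the fp32 parameters); the cache is cleared when the outermost `torch.autocast` region exits, and it can be disabled with the `cache_enabled` argument. A sketch to observe this, assuming a recent PyTorch where CPU autocast (bfloat16) uses the same cache: a small `TorchDispatchMode` (the `CastCounter` class is hypothetical) counts `aten._to_copy` calls, which is what dtype casts reach at the dispatcher level, over two forward passes in a single autocast region.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CastCounter(TorchDispatchMode):
    """Hypothetical helper: count aten._to_copy calls (dtype casts)."""
    def __init__(self):
        super().__init__()
        self.casts = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func == torch.ops.aten._to_copy.default:
            self.casts += 1
        return func(*args, **(kwargs or {}))

def count_casts(cache_enabled: bool) -> int:
    model = torch.nn.Linear(10, 10)
    x = torch.randn(2, 10)
    counter = CastCounter()
    with counter:
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16,
                            cache_enabled=cache_enabled):
            model(x)
            model(x)  # weight/bias casts may now be served from the cache
    return counter.casts

with_cache = count_casts(True)
without_cache = count_casts(False)
print(with_cache, without_cache)
```

If the cache is used, the second forward should only need to cast the input again, so `with_cache` is expected to be smaller than `without_cache`. On GPU, swap in `device_type="cuda"` and `torch.float16`.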

Is autocast implemented as some sort of wrapper around function calls? E.g., if I call several functions on the activation of a previous module, will there be two redundant fp32->fp16 casts for the two branches of computation? Magic :(
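Autocast works per op: each op on its lists has a wrapper registered under the autocast dispatch key that casts the op's tensor arguments before running it. Since only leaf tensors requiring grad go through the cache, I'd expect an fp32 non-leaf activation consumed by two autocast ops to be cast once per branch. A sketch to check this on CPU with bfloat16 (the `CastLogger` class is hypothetical), recording which tensor each `aten._to_copy` call reads from:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CastLogger(TorchDispatchMode):
    """Hypothetical helper: record the source data_ptr of every dtype cast."""
    def __init__(self):
        super().__init__()
        self.cast_inputs = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func == torch.ops.aten._to_copy.default:
            self.cast_inputs.append(args[0].data_ptr())
        return func(*args, **(kwargs or {}))

branch_a = torch.nn.Linear(10, 10)
branch_b = torch.nn.Linear(10, 10)
x = torch.randn(2, 10, requires_grad=True)
h = torch.relu(x)  # fp32 non-leaf activation shared by both branches

logger = CastLogger()
with logger:
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out_a = branch_a(h)
        out_b = branch_b(h)

casts_of_h = logger.cast_inputs.count(h.data_ptr())
print(casts_of_h)
```

If the shared activation is already in the lower-precision dtype (e.g., it came out of another autocast op), neither branch needs to cast it.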

This logging would also be much more useful if these extensions were implemented: albanD/subclass_zoo on GitHub (adding shape and dtype to the text trace).
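Until something like that lands, a rough stand-in is a small `TorchDispatchMode` that records each aten op together with the dtype and shape of its tensor inputs and outputs. The `ShapeDtypeLoggingMode` class and `describe` helper are hypothetical names, and `torch.utils._python_dispatch` / `torch.utils._pytree` are private modules that may change between releases.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

def describe(obj):
    """Summarize every tensor in a (possibly nested) structure as dtype[shape]."""
    leaves, _ = tree_flatten(obj)
    return [f"{t.dtype}{list(t.shape)}" for t in leaves if isinstance(t, torch.Tensor)]

class ShapeDtypeLoggingMode(TorchDispatchMode):
    """Hypothetical helper: log each aten op with the dtype/shape of its tensors."""
    def __init__(self):
        super().__init__()
        self.lines = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        self.lines.append(f"{func}: {describe((args, kwargs))} -> {describe(out)}")
        return out

model = torch.nn.Linear(10, 10)
x = torch.randn(1, 10)

mode = ShapeDtypeLoggingMode()
with mode:
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        model(x)

for line in mode.lines:
    print(line)
```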

Related issue: "Modernize logging tensor in torch.testing._internal" (pytorch/pytorch issue #81750 on GitHub)