Is it possible to obtain the number of activations in a model for a given input (assuming that the size only depends on the shapes and not the actual values in the input tensor) without initializing the weights/buffers?
Here is an example that initializes the weights.
I tried doing this with fake mode, but it seems not all ops are implemented:
import torch
import torch.nn as nn
from contextlib import nullcontext

# from torchdistx.fake import fake_mode
# with fake_mode():
with nullcontext():
    model = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        nn.ReLU(),
        nn.Conv2d(20, 64, 5),
        nn.ReLU()
    )

    total_activations = 0

    def count_activations_hook(module, input, output):
        global total_activations
        activations = output.numel()
        total_activations += activations
        print(f"{module.__class__.__name__} produced {activations} activations")

    for layer in model.children():
        layer.register_forward_hook(count_activations_hook)

    input_tensor = torch.randn(1, 1, 28, 28)
    model(input_tensor)

print(f"Total activations: {total_activations}")
@albanD just recently shared how to measure memory usage via TorchDispatchMode here. Applied to your model I see:
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map_only
from torch.utils._python_dispatch import TorchDispatchMode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils.weak import WeakIdKeyDictionary
import weakref
import math

# Track all the memory being used by Tensors.
# Only max is tracked but others can be added.
MEMORY_USE = WeakIdKeyDictionary()
MEMORY_MAX = 0
# Minimum allocation size
PYTORCH_MIN_ALLOCATE = 2**9

def update_stats():
    global MEMORY_MAX
    curr_use = 0
    for k, v in MEMORY_USE.items():
        curr_use += math.ceil(k.size() * k.element_size() / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE
    if MEMORY_MAX < curr_use:
        MEMORY_MAX = curr_use

# Should be called on every Tensor created
def track(t: torch.Tensor):
    def cb(_):
        update_stats()
    st = t.untyped_storage()
    wt = weakref.ref(st, cb)
    MEMORY_USE[st] = wt
    update_stats()

# Use this Mode to call track on every Tensor being created by functions
class MemoryTrackingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        res = func(*args, **(kwargs or {}))
        tree_map_only(torch.Tensor, track, res)
        return res

with FakeTensorMode(), MemoryTrackingMode():
    model = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        nn.ReLU(),
        nn.Conv2d(20, 64, 5),
        nn.ReLU()
    )
    input_tensor = torch.randn(1, 1, 28, 28)

    model(input_tensor)
    model(input_tensor)
    print(f"{MEMORY_MAX}")

    output = model(input_tensor)
    output = model(input_tensor)
    print(f"with return: {MEMORY_MAX}")
@ptrblck I don’t know if I'm allowed to ask you a question about this program; I hope it's OK. I ran the code you showed, and I found that if I call model(input_tensor) three or more times, the result stays the same. I don't understand why. May I assume that since the input data doesn't change, the CPU memory we calculate is static, so the result is the same?
PyTorch will cache and reuse device memory if you delete or overwrite variables. If you execute the same code multiple times the memory usage should be constant after the initial warmup phase.
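You can see this with the fake-mode tracker as well. A sketch reusing MEMORY_MAX and MemoryTrackingMode from my earlier snippet: since no reference to the output is kept, the intermediate storages die after each iteration and the recorded peak stops growing after the first pass:

MEMORY_MAX = 0  # reset the peak recorded by earlier runs

with FakeTensorMode(), MemoryTrackingMode():
    model = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        nn.ReLU(),
        nn.Conv2d(20, 64, 5),
        nn.ReLU()
    )
    input_tensor = torch.randn(1, 1, 28, 28)
    for i in range(5):
        # No reference to the output is kept, so the intermediates are
        # freed each iteration and the peak plateaus after warmup.
        model(input_tensor)
        print(f"iteration {i}: peak memory = {MEMORY_MAX}")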