Thanks. I wrote this function to calculate the activation size of a network, i.e. the total number of elements in all module outputs produced while the network runs:
# Total from the most recent calc_total_activation_size() call. Kept at module
# level only for backward compatibility with code that reads it directly.
total_output_elements = 0


def _count_tensor_elements(obj):
    """Return the summed ``numel()`` of every tensor nested inside *obj*.

    Tuples, lists and dict values are walked recursively (this covers
    NamedTuples and OrderedDict-style outputs). Anything else, e.g. ``None``,
    contributes 0.
    """
    if isinstance(obj, torch.Tensor):
        return obj.numel()
    if isinstance(obj, (tuple, list)):
        return sum(_count_tensor_elements(item) for item in obj)
    if isinstance(obj, dict):
        return sum(_count_tensor_elements(value) for value in obj.values())
    return 0


def calc_total_activation_size(model, call_the_network_function):
    """Run the network once and count the elements of every module output.

    Args:
        model: module whose ``modules()`` (itself and all descendants) get a
            forward hook. If ``None``, a process-wide forward hook is used
            instead, which counts every module that runs during the call.
        call_the_network_function: zero-argument callable that runs the
            forward pass(es) to measure.

    Returns:
        ``(result, total)``: *result* is whatever the callable returned;
        *total* is the summed element count of all hooked module outputs.

    Raises:
        Whatever ``call_the_network_function`` raises. Hooks are removed first.

    Note:
        Container modules (e.g. ``nn.Sequential``) are hooked too, so their
        output is counted in addition to their children's outputs. The same
        tensor may therefore be counted more than once.
    """
    global total_output_elements
    total = 0

    def hook(module, inputs, output):
        nonlocal total
        # Outputs may be tensors or (nested) tuples/lists/dicts of tensors.
        total += _count_tensor_elements(output)

    if model is None:
        # Original behaviour: a global hook fires for every module in the process.
        handles = [torch.nn.modules.module.register_module_forward_hook(hook)]
    else:
        handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        result = call_the_network_function()
    finally:
        # Always unregister, otherwise a failed forward leaks the hooks and
        # corrupts the counts of every later forward pass.
        for handle in handles:
            handle.remove()
        total_output_elements = total
    return result, total