PyTorch appears to be crashing due to OOM prematurely?

Thanks. I wrote this function to calculate the total activation size of a network, i.e. the number of elements in every module's forward output:


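import torch

def calc_total_activation_size(model, call_the_network_function):
    total_output_elements = 0

    def count_elements(out):
        # Outputs may be tensors or (nested) tuples/lists of tensors, e.g. nn.LSTM.
        if isinstance(out, torch.Tensor):
            return out.numel()
        if isinstance(out, (tuple, list)):
            return sum(count_elements(o) for o in out)
        return 0

    def hook(module, input, output):
        nonlocal total_output_elements
        total_output_elements += count_elements(output)

    # Global hook: fires for every nn.Module's forward, not only `model`'s submodules.
    handle = torch.nn.modules.module.register_module_forward_hook(hook)
    try:
        result = call_the_network_function()
    finally:
        handle.remove()  # unregister even if the forward pass raises (e.g. OOM)
    return result, total_output_elements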
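The function counts elements, not bytes. To compare the count with the memory figures in an OOM message, a rough conversion is elements × bytes per element (4 for float32, 2 for float16/bfloat16). Below is a minimal sketch, assuming every activation shares a single dtype; `activation_bytes`, `format_gib` and `DTYPE_BYTES` are hypothetical helper names. The estimate ignores weights, gradients, optimizer state and allocator overhead, and the hook count itself can double count container modules such as `nn.Sequential`, whose output is the same tensor as their last child's.

```python
# Hypothetical helpers: rough memory estimate for an activation element count.
# Assumes all activations share one dtype; ignores weights, gradients,
# optimizer state and CUDA caching-allocator overhead.
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2, "float64": 8}


def activation_bytes(total_output_elements: int, dtype: str = "float32") -> int:
    """Return the bytes needed to hold `total_output_elements` values of `dtype`."""
    return total_output_elements * DTYPE_BYTES[dtype]


def format_gib(num_bytes: int) -> str:
    """Format a byte count in GiB (1 GiB = 2**30 bytes)."""
    return f"{num_bytes / 2**30:.2f} GiB"


if __name__ == "__main__":
    # Example: 500 million activation elements.
    n = 500_000_000
    print(format_gib(activation_bytes(n)))             # 1.86 GiB
    print(format_gib(activation_bytes(n, "float16")))  # 0.93 GiB
```

If the estimate is far below the "total capacity" reported in the OOM message, the remaining memory is likely going to something the hook does not see.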