Unexpected behavior of register_forward_hook / register_full_backward_hook when increasing the batch size

I tried to profile a model using forward/backward hooks, but the behavior is not what I expected for batch normalization layers when I increase the batch size. Here is a minimal reproducible example:

import copy

import torch
from torch import nn
from torchvision.models.mobilenetv3 import mobilenet_v3_small


def register_bn_counter(net, input_size):
    FWD_TOTAL = torch.Tensor([0])
    BWD_TOTAL = torch.Tensor([0])
    NUM_BN_LAYERS = 0

    net_ = copy.deepcopy(net)
    for m in net_.modules():
        if hasattr(m, "inplace"):
            m.inplace = False
    
    hook_handlers = []

    def count_bn_fwd(m, x, y):
        m.counter = torch.Tensor([x[0].numel() * 4]) # bytes: 4 bytes per float32 element

    def count_bn_bwd(m, gx, gy):
        m.counter += torch.Tensor([1])

    def add_hooks(m_):
        if type(m_) in [nn.BatchNorm2d]:
            m_.register_buffer("counter", torch.zeros(1))
            hook_handlers.append(m_.register_forward_hook(count_bn_fwd))
            hook_handlers.append(m_.register_full_backward_hook(count_bn_bwd))

    net_.apply(add_hooks)

    criterion = torch.nn.CrossEntropyLoss()

    x = torch.zeros(input_size)
    if list(net_.parameters()):
        x = x.to(next(net_.parameters()).device)  # .to() returns a new tensor, it is not in-place

    # do forward, all counters in bn layers record input feature map size in byte
    y = net_(x)
    for m in net_.modules():
        if hasattr(m, "counter"):
            NUM_BN_LAYERS += 1
            FWD_TOTAL += m.counter

    # do backward, all counters in bn layers should be added 1
    loss = criterion(y, torch.zeros(input_size[0], dtype=int))
    loss.backward()
    for m in net_.modules():
        if hasattr(m, "counter"):
            BWD_TOTAL += m.counter

    # expect BWD_TOTAL - FWD_TOTAL == NUM_BN_LAYERS (e.g. tensor([34.]) 34), but it prints tensor([4.]) 34 when batch_size = 8
    print(BWD_TOTAL - FWD_TOTAL, NUM_BN_LAYERS)
    assert BWD_TOTAL - FWD_TOTAL == NUM_BN_LAYERS
    
    for h in hook_handlers:
        h.remove()

    return net_

def calculate_bn(net, input_size):
    net_ = copy.deepcopy(net)

    class ForwardProfiler():
        def __init__(self):
            self.sum = 0

        def hook(self, m, x, y):
            if hasattr(m, "counter"):
                self.sum += m.counter

    class BackwardProfiler():
        def __init__(self):
            self.sum = 0

        def hook(self, m, gx, gy):
            if hasattr(m, "counter"):
                self.sum += m.counter

    fwd_prof = ForwardProfiler()
    bwd_prof = BackwardProfiler()
    for m in net_.modules():
        if hasattr(m, "inplace"):
            m.inplace = False
        handler_ = m.register_forward_hook(fwd_prof.hook)
        handler_ = m.register_full_backward_hook(bwd_prof.hook)

    criterion = torch.nn.CrossEntropyLoss()

    x = torch.zeros(input_size)
    if list(net_.parameters()):
        x = x.to(next(net_.parameters()).device)  # .to() returns a new tensor, it is not in-place
    y = net_(x)
    loss = criterion(y, torch.zeros(input_size[0], dtype=int))
    loss.backward()
    print(fwd_prof.sum, bwd_prof.sum)
    assert fwd_prof.sum == bwd_prof.sum

def test():
    # model = mobilenet_v3_small()
    model = torch.nn.Sequential(
        torch.nn.Conv2d(
            in_channels=3, out_channels=16,
            kernel_size=(3, 3), stride=1, padding=0, dilation=1,
            groups=1, bias=True, padding_mode='zeros', device=None, dtype=None),
        nn.BatchNorm2d(num_features=16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
        nn.ReLU(inplace=False),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(start_dim=1),
    )
    model.train()
    input_size = (8, 3, 224, 224)  # triggers the assertion error
    # input_size = (2, 3, 224, 224)  # batch sizes 1, 2, and 3 work as expected
    model_bn_counter = register_bn_counter(model, input_size)
    calculate_bn(model_bn_counter, input_size)

if __name__ == '__main__':
    test()

Thanks

I haven’t checked what your code is doing exactly, but you are dealing with values larger than 2**24 (e.g. in m.counter) in float32. Beyond 2**24 float32 can no longer represent every integer, so values are rounded to multiples of 2 (then 4, 8, … as the magnitude grows).
Check the Wikipedia article on the single-precision floating-point format to see the precision limits.
E.g. using the larger batch size shows m.counter as 25233408, which lies in [2**24, 2**25] and will thus be rounded to a multiple of 2:

torch.tensor(25233408., dtype=torch.float32)
> tensor(25233408.)

torch.tensor(25233409., dtype=torch.float32)
> tensor(25233408.)

torch.tensor(25233409., dtype=torch.float64)
> tensor(25233409., dtype=torch.float64)
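
(For reference: in the toy model from the question with batch size 8, the BatchNorm2d input has shape (8, 16, 222, 222), so x[0].numel() * 4 = 8 * 16 * 222 * 222 * 4 = 25233408 bytes, which is exactly the value above.)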

Initialize m.counter and the _TOTAL accumulators as float64 and it should work.
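
A minimal sketch of the failure mode (not from the original posts), assuming only torch is available: incrementing a float32 counter that already holds 25233408 silently loses the +1 to rounding, while a float64 counter keeps it:

import torch

# float32 has a 24-bit significand, so not every integer above 2**24 is representable
counter32 = torch.tensor([25233408.], dtype=torch.float32)
counter64 = torch.tensor([25233408.], dtype=torch.float64)

print(counter32 + 1 - counter32)  # tensor([0.]) -- the increment rounds away
print(counter64 + 1 - counter64)  # tensor([1.], dtype=torch.float64)

This is exactly why BWD_TOTAL - FWD_TOTAL comes out short of NUM_BN_LAYERS: for layers whose counter exceeds 2**24, the per-layer +1 increments are rounded away.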


Thank you! I used DoubleTensor and it works!

import copy

import torch
from torch import nn
from torchvision.models.mobilenetv3 import mobilenet_v3_small


def register_bn_counter(net, input_size):
    FWD_TOTAL = torch.DoubleTensor([0])
    BWD_TOTAL = torch.DoubleTensor([0])
    NUM_BN_LAYERS = 0

    net_ = copy.deepcopy(net)
    for m in net_.modules():
        if hasattr(m, "inplace"):
            m.inplace = False
    
    hook_handlers = []

    def count_bn_fwd(m, x, y):
        m.counter = torch.DoubleTensor([x[0].numel() * 4]) # bytes: 4 bytes per float32 element

    def count_bn_bwd(m, gx, gy):
        m.counter += torch.DoubleTensor([1])

    def add_hooks(m_):
        if type(m_) in [nn.BatchNorm2d]:
            m_.register_buffer("counter", torch.zeros(1, dtype=torch.float64))
            hook_handlers.append(m_.register_forward_hook(count_bn_fwd))
            hook_handlers.append(m_.register_full_backward_hook(count_bn_bwd))

    net_.apply(add_hooks)

    criterion = torch.nn.CrossEntropyLoss()

    x = torch.zeros(input_size)
    if list(net_.parameters()):
        x = x.to(next(net_.parameters()).device)  # .to() returns a new tensor, it is not in-place

    # do forward, all counters in bn layers record input feature map size in byte
    y = net_(x)
    for m in net_.modules():
        if hasattr(m, "counter"):
            NUM_BN_LAYERS += 1
            FWD_TOTAL += m.counter

    # do backward, all counters in bn layers should be added 1
    loss = criterion(y, torch.zeros(input_size[0], dtype=int))
    loss.backward()
    for m in net_.modules():
        if hasattr(m, "counter"):
            BWD_TOTAL += m.counter

    # with float64 counters this now prints the expected values and the assertion holds
    print(BWD_TOTAL - FWD_TOTAL, NUM_BN_LAYERS)
    assert BWD_TOTAL - FWD_TOTAL == NUM_BN_LAYERS
    
    for h in hook_handlers:
        h.remove()

    return net_

def calculate_bn(net, input_size):
    net_ = copy.deepcopy(net)

    class ForwardProfiler():
        def __init__(self):
            self.sum = torch.DoubleTensor([0])

        def hook(self, m, x, y):
            if hasattr(m, "counter"):
                self.sum += m.counter

    class BackwardProfiler():
        def __init__(self):
            self.sum = torch.DoubleTensor([0])

        def hook(self, m, gx, gy):
            if hasattr(m, "counter"):
                self.sum += m.counter

    fwd_prof = ForwardProfiler()
    bwd_prof = BackwardProfiler()
    for m in net_.modules():
        if hasattr(m, "inplace"):
            m.inplace = False
        handler_ = m.register_forward_hook(fwd_prof.hook)
        handler_ = m.register_full_backward_hook(bwd_prof.hook)

    criterion = torch.nn.CrossEntropyLoss()

    x = torch.zeros(input_size)
    if list(net_.parameters()):
        x = x.to(next(net_.parameters()).device)  # .to() returns a new tensor, it is not in-place
    y = net_(x)
    loss = criterion(y, torch.zeros(input_size[0], dtype=int))
    loss.backward()
    print(fwd_prof.sum, bwd_prof.sum)
    assert fwd_prof.sum == bwd_prof.sum

def test():
    # model = mobilenet_v3_small()
    model = torch.nn.Sequential(
        torch.nn.Conv2d(
            in_channels=3, out_channels=16,
            kernel_size=(3, 3), stride=1, padding=0, dilation=1,
            groups=1, bias=True, padding_mode='zeros', device=None, dtype=None),
        nn.BatchNorm2d(num_features=16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None),
        nn.ReLU(inplace=False),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(start_dim=1),
    )
    model.train()
    input_size = (8, 3, 224, 224)
    model_bn_counter = register_bn_counter(model, input_size)
    calculate_bn(model_bn_counter, input_size)

if __name__ == '__main__':
    test()
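
(Side note, not from the original thread: torch.DoubleTensor is a legacy constructor; torch.tensor([0.], dtype=torch.float64) and torch.zeros(1, dtype=torch.float64) are the current idioms and behave the same here.)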