Conditional execution of forward hooks

I am using some forward hooks on intermediate layers for debugging/visualisation purposes, and I am wondering whether it is possible to execute the hook code only on every n-th batch. In other words, is there any global switch that toggles hook execution on and off?

I see two options.

Registering a hook returns a handle (a torch.utils.hooks.RemovableHandle) that can also be used as a context manager. When the context manager exits, the hook is removed. This is probably not documented in the official docs, but you can check the source code for that.

The snippet below shows how this works:

with my_net.layer.register_forward_hook(my_hook):
    # here the hook is active
    my_net(input)

# here the hook no longer gets called
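To apply this to the original question (running the hook only on every n-th batch), one possible sketch registers the hook inside the with-block only for those batches and uses contextlib.nullcontext for all other batches, so there is still a single forward call per batch. The toy model, N, make_hook and the captured list are illustrative assumptions, not part of the original post.

```python
import contextlib

import torch
import torch.nn as nn

# Hypothetical toy model; net[0] stands in for my_net.layer from the post.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = []  # batch indices for which the hook actually ran


def make_hook(batch_idx):
    # PyTorch calls forward hooks as hook(module, input, output).
    def hook(module, inputs, output):
        captured.append(batch_idx)
    return hook


N = 5  # assumed value: run the hook on every N-th batch
for batch_idx in range(12):
    x = torch.randn(3, 4)  # placeholder batch
    if batch_idx % N == 0:
        # The returned handle is a context manager; leaving the
        # with-block removes the hook from the layer again.
        ctx = net[0].register_forward_hook(make_hook(batch_idx))
    else:
        ctx = contextlib.nullcontext()  # no hook for this batch
    with ctx:
        net(x)  # forward pass; the hook fires only if registered

print(captured)  # [0, 5, 10]
```

Since the hook is removed on every exit from the with-block, nothing stays attached to the layer between the selected batches.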

Or you could write your own decorator that executes a hook every N calls.

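import functools

def execute_every(n):
    def decorator(f):
        cnt_exec = 0  # number of times the hook has been called so far

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            nonlocal cnt_exec
            cnt_exec += 1
            # Run the wrapped hook only on every n-th call; otherwise return None,
            # which tells PyTorch to leave the layer's output unchanged.
            if cnt_exec % n == 0:
                return f(*args, **kwargs)

        return wrapper
    return decorator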

@execute_every(5)
def my_hook(module, input, output):
    ...  # debugging/visualisation code goes here

my_net.layer.register_forward_hook(my_hook)
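The counting logic of the decorator can be checked without a model. The sketch below (a self-contained copy of the decorator, with a hypothetical fired list) calls the wrapped hook directly with placeholder arguments in the same positional order PyTorch uses for forward hooks, module, input, output, and records which calls actually reached the hook body.

```python
import functools


def execute_every(n):
    """Decorator: run the wrapped hook only on every n-th call."""
    def decorator(f):
        cnt_exec = 0

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            nonlocal cnt_exec
            cnt_exec += 1
            if cnt_exec % n == 0:
                return f(*args, **kwargs)
            return None  # skipped call: output is left unchanged
        return wrapper
    return decorator


fired = []  # hypothetical record of the calls that reached the hook body


@execute_every(5)
def my_hook(module, inputs, output):
    fired.append(output)


# Simulate 12 forward passes; a real model would call my_hook(module, input, output).
for step in range(1, 13):
    my_hook(None, (), step)

print(fired)  # [5, 10]
```

Note that the counter counts hook calls, not batches: it matches "every N-th batch" only if the hooked layer runs exactly once per batch, so extra forward passes (for example, validation) also advance the count.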

Oh, it’s that simple 🙂 Cool, and thank you!