Monkey Patch Hooking BatchNormBackward and ConvNdBackward

Hi,

Since the pre/post hook functionality and .creator switch isn’t implemented yet in master, we’re trying to implement a simple profiling/visualization tool by monkey patching all of the functions we can find in the module.

We’re having trouble navigating all the different types and formats of functions, including THNN built-ins, PyTorch C++ built-ins, autogenerated backward nodes, and some edge cases like BatchNormBackward. We managed to handle all but the last using the code below.

  1. How can we hook into these BatchNormBackward nodes? They do not have forward/backward/apply and seem to implement only __call__, which is immutable from Python. This seems different from all of the other functions. Why?

  2. Sometimes THNN-implemented torch.nn functions have the instance (ctx/self?) passed in with their Backward type, e.g. View becomes ViewBackward during the forward call. Why does this happen?
    Solved: this is because the instance/context is actually not guaranteed to be any type. It just needs to be able to hold saved_tensors etc. See this pull request.

  3. Is there currently a better way to find a list of all possible graph nodes in PyTorch and monkey patch them?

Any clarifications on the current state of the compute graph would help immensely as well.

Thanks!

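import functools
import itertools
import time
from functools import partial
from inspect import isclass

from torch.autograd import function as base_functions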
from torch.autograd import _functions as _autograd_funcs
from torch.nn import _functions as _nn_funcs
import torch.nn._functions.thnn.auto as _auto
import torch._C as _C
from torch import nn
from torch.autograd import Function, Variable

def _wrap_func(func, pre_func, post_func):
    @functools.wraps(func)
    def run(*args, **kwargs):
        pre_func(*args, **kwargs)
        result = func(*args, **kwargs)
        post_func(*args, **kwargs, result=result)
        return result

    return run

class MonkeyPatcher:
    _FN_BLACKLIST = {'Function', 'InplaceFunction', 'BackwardCFunction'}

    def __init__(self):
        self.level = 0
        self._patched_fn_classes = set()
        self._time_begin_by_fn = {}
        self._time_end_by_fn = {}

    def monkey_patch(self):
        all_funcs = itertools.chain(
            _autograd_funcs.__dict__.values(),
            _nn_funcs.__dict__.values(),
            _auto.__dict__.values(),
            nn.modules.__dict__.values(),
            base_functions.__dict__.values())
        for fn_cls in all_funcs:
            if not isclass(fn_cls) or fn_cls.__name__ in self._FN_BLACKLIST:
                continue
            if fn_cls in self._patched_fn_classes:
                continue
            if not (issubclass(fn_cls, Function)
                    or issubclass(fn_cls, _C._FunctionBase)
                    or issubclass(fn_cls, nn.Module)):
                continue

            if hasattr(fn_cls, 'forward'):
                fn_cls.forward = _wrap_func(
                    fn_cls.forward,
                    partial(self._forward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._forward_post, fn_cls=fn_cls, fn_inst=None))
            if hasattr(fn_cls, 'backward'):
                fn_cls.backward = _wrap_func(
                    fn_cls.backward,
                    partial(self._backward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._backward_post, fn_cls=fn_cls, fn_inst=None))
            if hasattr(fn_cls, 'apply'):
                fn_cls.apply = _wrap_func(
                    fn_cls.apply,
                    partial(self._forward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._forward_post, fn_cls=fn_cls, fn_inst=None))
            self._patched_fn_classes.add(fn_cls)

    def _indent(self):
        return '  ' * self.level

    def _forward_pre(self, *args, fn_cls: type, fn_inst, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        print(self._indent() + 'Forward Pre ' + fn_cls.__name__)
        fn_inst.forward_time_begin = time.time()
        self.level += 1

    def _forward_post(self, *args, fn_cls: type, fn_inst, result, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        self.level -= 1
        print(self._indent() + 'Forward Post ' + fn_cls.__name__)
        fn_inst.forward_elapsed = time.time() - fn_inst.forward_time_begin

    def _backward_pre(self, *args, fn_cls: type, fn_inst, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        print(self._indent() + 'Backward Pre ' + fn_cls.__name__)
        fn_inst.backward_time_begin = time.time()
        self.level += 1

    def _backward_post(self, *args, fn_cls: type, fn_inst, result, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        self.level -= 1
        print(self._indent() + 'Backward Post ' + fn_cls.__name__)
        fn_inst.backward_elapsed = time.time() - fn_inst.backward_time_begin

Hi,

BatchNormBackward is different from the python Function because it is implemented purely in C++.
The goal of these C++ functions is to be able to run computations without having to deal with Python (and especially the GIL), so it is expected that you cannot modify them from Python.
You can still attach hooks to them, as with any other function, using the register_hook() method.
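
As a loose analogy in plain Python (the built-in int and a hypothetical PyOp class are stand-ins here, not PyTorch's actual node types), a class defined in Python accepts a patched method, while a type implemented in C rejects the assignment:

```python
import functools


class PyOp:
    """Hypothetical stand-in for a Function subclass written in Python."""

    def forward(self, x):
        return x * 2


def wrap(func, log):
    """Wrap func so each call is recorded in log before and after it runs."""
    @functools.wraps(func)
    def run(*args, **kwargs):
        log.append("pre " + func.__name__)
        result = func(*args, **kwargs)
        log.append("post " + func.__name__)
        return result
    return run


calls = []

# Patching a Python-defined class works: the class attribute is replaced.
PyOp.forward = wrap(PyOp.forward, calls)
doubled = PyOp().forward(21)

# Patching a type implemented in C fails: its attributes cannot be reassigned.
try:
    int.__add__ = wrap(int.__add__, calls)
    patched_builtin = True
except TypeError:
    patched_builtin = False
```

This is only meant to illustrate why the monkey-patching approach stops at the Python boundary; hooks registered through register_hook() do not rely on replacing attributes.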

The .creator attribute has been renamed to .grad_fn in master. You can look here for how to traverse the graph: https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
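
A minimal sketch of such a traversal, assuming each node exposes next_functions as a sequence of (node, input_index) pairs, which is how grad_fn objects behave in PyTorch. FakeNode and walk_graph are hypothetical names; FakeNode only stands in for real nodes so the sketch runs without PyTorch:

```python
class FakeNode:
    """Hypothetical stand-in mimicking the next_functions attribute of a grad_fn."""

    def __init__(self, name, inputs=()):
        self.name = name
        # Same shape as in PyTorch: a tuple of (node, input_index) pairs.
        self.next_functions = tuple((node, 0) for node in inputs)


def walk_graph(root):
    """Return every node reachable from root, each exactly once, depth-first."""
    seen, order, stack = set(), [], [root]
    while stack:
        node = stack.pop()
        # Inputs that need no gradient show up as None in next_functions.
        if node is None or id(node) in seen:
            continue
        seen.add(id(node))
        order.append(node)
        # Push inputs in reverse so the first input is visited first.
        stack.extend(fn for fn, _ in reversed(getattr(node, "next_functions", ())))
    return order


leaf = FakeNode("AccumulateGrad")
shared = FakeNode("MulBackward", [leaf])
root = FakeNode("AddBackward", [shared, shared])
names = [node.name for node in walk_graph(root)]
```

With PyTorch, passing b.grad_fn as the root should walk the real backward nodes, and type(node).__name__ gives each node's class name (for example a BatchNormBackward node).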

Hi, the thing is we need prepended hooks (before the function is run) so that we can keep track of the computation time. I’m guessing there’s no good way to do this right now until the pre_hook/post_hook functionality is hooked up (heh) with a python interface?

Hi,

You can check the following snippet:

import time
import torch
from torch.autograd import Variable
from torch import nn


current = 0
def hook_fn(msg):
    def tmp(*args):
        global current
        new = time.time()
        diff = new - current
        current = new
        print("Hello {} at {}".format(msg, diff))
    return tmp

a = Variable(torch.rand(10, 10000), requires_grad=True)
fn = nn.Linear(10000, 50000)

b = fn(a)

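# Runs once the gradient for a is available, i.e. after the linear backward.
a.register_hook(hook_fn("var a"))
# Runs just after the linear backward has executed.
b.grad_fn.register_hook(hook_fn("grad fn"))
# Runs once the gradient for b is available, i.e. just before the linear backward.
b.register_hook(hook_fn("var b"))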

current = time.time()
b.sum().backward()

As you can see from the runtimes, the hook on b is one of the hooks executed just before the linear function's backward is called. The hook on b.grad_fn is one of the hooks executed just after the linear function's backward has run, and the hook on a also runs just after the linear function's backward. You can change the size of the linear output and see which part of the timing changes.

tl;dr: You can get pre_hook/post_hook behavior with register_hook on the variable and register_hook on variable.grad_fn.