Hi,
Since the pre/post hook functionality and the `.creator` switch aren't implemented yet in master, we're trying to build a simple profiling/visualization tool by monkey patching all of the functions we can find in the module.
We're having trouble navigating the different types and formats of functions involved, including THNN built-ins, PyTorch C++ built-ins, autogenerated backward nodes, and some edge cases like BatchNormBackward. All but the last we managed to solve using the code below.
- How can we hook into these `BatchNormBackward` nodes? They do not have `forward`/`backward`/`apply` and seem to only implement `__call__`, which is immutable from Python. This seems different from all of the other functions; why? (A workaround sketch follows this list.)
- Sometimes THNN-implemented `torch.nn` functions have the instance (ctx/self?) passed in with their Backward type, e.g. `View` becomes `ViewBackward` during the forward call. Why does this happen? Solved: the instance/context is actually not guaranteed to be any particular type; it just needs to be able to hold `saved_tensors` etc. See this pull request. (A tiny illustration follows below.)
- Is there currently a better way to find a list of all the possible graph nodes in PyTorch and monkey patch them?
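For the first question, the workaround we're considering (just a sketch; we'd love to hear if there's a proper way) is to avoid touching the node class entirely and instead attach `register_hook` to the Variables around it: a hook fires when the gradient w.r.t. that Variable is computed, so hooks on a node's output and input bracket its backward execution. The `bn`/`x` names here are just for illustration:

import time

import torch
from torch import nn
from torch.autograd import Variable

bn = nn.BatchNorm1d(4)
x = Variable(torch.randn(8, 4), requires_grad=True)
out = bn(x)

stamps = {}

def stamp(name):
    def hook(grad):
        # Fires during backward, once grad w.r.t. the hooked Variable is ready.
        stamps[name] = time.time()
        return grad  # leave the gradient unchanged
    return hook

out.register_hook(stamp('bn_out'))  # just before BatchNormBackward consumes it
x.register_hook(stamp('bn_in'))     # just after it has produced grad_input
out.sum().backward()
print(stamps['bn_in'] - stamps['bn_out'])  # rough backward time for the node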
Any clarifications on the current state of the compute graph would help immensely as well.
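And to illustrate the "Solved" point from the second question: in the old-style Function API the object passed to forward/backward is effectively a scratch context, so its concrete type doesn't matter to the engine as long as it can hold `saved_tensors` and friends. MulTwo below is our own toy example, not a PyTorch class:

import torch
from torch.autograd import Function, Variable

class MulTwo(Function):
    def forward(self, x):
        # "self" acts as the context here; the engine only needs it to
        # hold saved state such as saved_tensors.
        self.save_for_backward(x)
        return x.mul(2)

    def backward(self, grad_output):
        x, = self.saved_tensors
        return grad_output.mul(2)

y = MulTwo()(Variable(torch.ones(3), requires_grad=True))
y.sum().backward()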
Thanks!
import functools
import itertools
import time
from functools import partial
from inspect import isclass

from torch import nn
from torch.autograd import Function, Variable
from torch.autograd import function as base_functions
from torch.autograd import _functions as _autograd_funcs
from torch.nn import _functions as _nn_funcs
import torch.nn._functions.thnn.auto as _auto
import torch._C as _C
def _wrap_func(func, pre_func, post_func):
    """Wrap func so pre_func/post_func run before/after every call."""
    @functools.wraps(func)
    def run(*args, **kwargs):
        pre_func(*args, **kwargs)
        result = func(*args, **kwargs)
        post_func(*args, result=result, **kwargs)
        return result
    return run
class MonkeyPatcher:
    # Base classes we must not wrap.
    _FN_BLACKLIST = {'Function', 'InplaceFunction', 'BackwardCFunction'}

    def __init__(self):
        self.level = 0
        self._patched_fn_classes = set()
        self._time_begin_by_fn = {}
        self._time_end_by_fn = {}

    def monkey_patch(self):
        # Every place we know of (so far) where graph-node classes live.
        all_funcs = itertools.chain(
            _autograd_funcs.__dict__.values(),
            _nn_funcs.__dict__.values(),
            _auto.__dict__.values(),
            nn.modules.__dict__.values(),
            base_functions.__dict__.values())
        for fn_cls in all_funcs:
            if not isclass(fn_cls) or fn_cls.__name__ in self._FN_BLACKLIST:
                continue
            if fn_cls in self._patched_fn_classes:
                continue
            if not (issubclass(fn_cls, Function)
                    or issubclass(fn_cls, _C._FunctionBase)
                    or issubclass(fn_cls, nn.Module)):
                continue
            # Wrap whichever entry points the class actually exposes.
            if hasattr(fn_cls, 'forward'):
                fn_cls.forward = _wrap_func(
                    fn_cls.forward,
                    partial(self._forward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._forward_post, fn_cls=fn_cls, fn_inst=None))
            if hasattr(fn_cls, 'backward'):
                fn_cls.backward = _wrap_func(
                    fn_cls.backward,
                    partial(self._backward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._backward_post, fn_cls=fn_cls, fn_inst=None))
            if hasattr(fn_cls, 'apply'):
                fn_cls.apply = _wrap_func(
                    fn_cls.apply,
                    partial(self._forward_pre, fn_cls=fn_cls, fn_inst=None),
                    partial(self._forward_post, fn_cls=fn_cls, fn_inst=None))
            self._patched_fn_classes.add(fn_cls)
    def _indent(self):
        return ' ' * self.level

    def _forward_pre(self, *args, fn_cls: type, fn_inst, **kwargs):
        if fn_inst is None:
            # The wrapped callable is a method, so args[0] is self/ctx.
            fn_inst = args[0]
        print(self._indent() + 'Forward Pre ' + fn_cls.__name__)
        fn_inst.forward_time_begin = time.time()
        self.level += 1

    def _forward_post(self, *args, fn_cls: type, fn_inst, result, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        self.level -= 1
        print(self._indent() + 'Forward Post ' + fn_cls.__name__)
        fn_inst.forward_elapsed = time.time() - fn_inst.forward_time_begin

    def _backward_pre(self, *args, fn_cls: type, fn_inst, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        print(self._indent() + 'Backward Pre ' + fn_cls.__name__)
        fn_inst.backward_time_begin = time.time()
        self.level += 1

    def _backward_post(self, *args, fn_cls: type, fn_inst, result, **kwargs):
        if fn_inst is None:
            fn_inst = args[0]
        self.level -= 1
        print(self._indent() + 'Backward Post ' + fn_cls.__name__)
        fn_inst.backward_elapsed = time.time() - fn_inst.backward_time_begin
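For reference, this is roughly how we drive it; the model and input are placeholders, and any small net shows the nested trace:

import torch

patcher = MonkeyPatcher()
patcher.monkey_patch()

model = nn.Linear(4, 2)  # placeholder; nn and Variable are imported above
x = Variable(torch.randn(3, 4), requires_grad=True)
out = model(x)        # prints the nested Forward Pre/Post lines
out.sum().backward()  # prints the Backward Pre/Post lines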