Monkey Patching Failing with __getitem__

Below is a complete NN that learns to multiply by 10. Its job is irrelevant; it is just a tight example to illustrate the problem.

At the beginning of my forward, I have these lines (meaningless in this example, but crucial in my larger project):

y = x<8.0
x[y] = 0

This causes my monkey patching to crash, even though, in this example, the wrapper does nothing other than return the result. If I ignore __getitem__, it works, so it has something to do with the indexing.

If I don’t ignore __getitem__, the error is:

“The shape of the mask [6] at index 0 does not match the shape of the indexed tensor [1, 6] at index 0”

If I ignore __getitem__, then it “works”, but lines like

x[:,1:3] = 0

no longer function (the above line works fine when __getitem__ is not ignored).

Note that my wrapper function should be a pure identity:

def _add_wrapper(mod, fn_name):
    func = getattr(mod, fn_name)

    def call(*args, **kwargs):
        # Call the original function
        result = func(*args, **kwargs)
        return result

    setattr(mod, fn_name, call)

My goal is to graph out the execution graph of PyTorch, so if there is an easier way to accomplish this, I’m all ears. For nn.Module classes, I can use register_forward_pre_hook and register_forward_hook and do what I need flawlessly. But since PyTorch allows intermixing functional calls inside forward, those are difficult to hook into and capture as they execute, so I’m trying monkey patching, with limited success (I have gotten it to work on some models but not others, depending on the operations used).
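
For reference, this is roughly what the hook-based approach looks like; it works for nn.Module subclasses but cannot see tensor-level operations such as x[y] = 0 inside forward. The record list and hook names are illustrative, not part of my real project.

import torch
import torch.nn as nn

record = []

def pre_hook(module, inputs):
    # Log module entry with the shapes of its tensor inputs
    shapes = tuple(t.shape for t in inputs if isinstance(t, torch.Tensor))
    record.append(('enter', type(module).__name__, shapes))

def post_hook(module, inputs, output):
    # Log module exit with the shape of its output (if it is a tensor)
    shape = output.shape if isinstance(output, torch.Tensor) else None
    record.append(('exit', type(module).__name__, shape))

def attach_hooks(net):
    # Register both hooks on every submodule of the network
    for m in net.modules():
        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(post_hook)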

The code below should train in pretty much any environment: it is extremely simple, uses only the CPU, and trains to completion in a second or two.

To see the code run to completion, either:

  1. comment out the first two lines in forward,
  2. add __getitem__ to the ignore list, or
  3. don’t monkey patch.

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

import inspect as ins

def _isfunc(mod, f):
    if not hasattr(mod, f):
        return False

    attr = getattr(mod, f)

    # Ignore functions from this list
    # if __setitem__ and __getitem__ are ignored, the crash goes away, but I also
    # cannot capture slicing operations.
    ignore = [
        '__all__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__builtins__', '__cached__',
        '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__file__',
        '__format__', '__getattribute__', '__hash__', '__index__', '__init__', '__init_subclass__',
        '__iter__', '__len__', '__loader__', '__module__', '__name__', '__new__', '__nonzero__', '__package__',
        '__path__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', # '__setitem__', '__getitem__',
        '__setstate__', '__sizeof__', '__spec__', '__str__', '__subclasshook__', '__version__', '__weakref__'
    ]

    # Add functions to this list if they cause recursion
    ignore += ['size', 'tolist', 'dim', 'is_storage', 'item', 'is_grad_enabled']
    if f in ignore:
        return False

    return ins.ismethod(attr) or ins.isfunction(attr) or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr)


def _add_wrapper(mod, fn_name):
    func = getattr(mod, fn_name)

    def call(*args, **kwargs):
        # Call the original function
        result = func(*args, **kwargs)
        return result

    setattr(mod, fn_name, call)


def _patchClass(cls):
    for f in dir(cls):
        if _isfunc(cls, f):
            _add_wrapper(cls, f)

# Monkey-patch classes in torch
def _patch_torch_classes():
    for cls in [
        torch,
        torch.Tensor,       # this is the class that is causing problems
        torch.nn.functional,
        torch.cuda,
        torch.distributed
    ]:
        _patchClass(cls)

# Monkey-patch forward functions in torch.nn libraries
def _patch_torch_nn_forward_functions():
    for cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]:
        if _isfunc(cls, 'forward'):
            _add_wrapper(cls, 'forward')


# Simple Net that learns to multiply by 10
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense1 = nn.Linear(6, 6)

    def forward(self, x):
        # the following two lines will crash if monkey-patching is enabled.
        y = x<8.0
        x[y] = 0
        # This line does not cause a crash if monkey-patching is enabled.
        # x[:,1:3] = 0

        x = x.flatten()
        x = self.dense1(x)

        return x

def main():
    net = Net()

    # monkey patch pytorch
    _patch_torch_classes()
    _patch_torch_nn_forward_functions()

    # Loss
    criterion = nn.L1Loss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    # two 1x6 input tensors
    pth_input0 = torch.tensor(np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), dtype=torch.float32)
    pth_input1 = torch.tensor(np.array([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]), dtype=torch.float32)
    pth_input = [pth_input0, pth_input1]

    # two 1x6 label tensors
    pth_ref0 = torch.tensor(np.array([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]]), dtype=torch.float32)
    pth_ref1 = torch.tensor(np.array([[70.0, 80.0, 90.0, 100.0, 110.0, 120.0]]), dtype=torch.float32)
    pth_ref = [pth_ref0, pth_ref1]

    # Training Loop
    for epoch in range(100):  # loop over the dataset multiple times

        for batch_index, data in enumerate(pth_input):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(pth_input[batch_index])
            loss = criterion(outputs, pth_ref[batch_index])
            loss.backward()
            optimizer.step()

            print('[%d] loss: %.3f' % (epoch + 1, loss.item()))

    print(outputs.detach().cpu().numpy())
    print('Finished Training')


if __name__ == '__main__':
    main()
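
For anyone who wants to reproduce just the failing part without the training loop, the stripped-down version below patches only __getitem__ and __setitem__ on torch.Tensor (the two entries that, per the ignore-list comment above, control whether the crash appears). This is a sketch of the behaviour described in this post, not a separately verified test.

import torch

def _identity_patch(cls, name):
    orig = getattr(cls, name)

    def call(*args, **kwargs):
        # Pure pass-through to the original method
        return orig(*args, **kwargs)

    setattr(cls, name, call)

# Patch only the two methods implicated above
_identity_patch(torch.Tensor, '__getitem__')
_identity_patch(torch.Tensor, '__setitem__')

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
y = x < 8.0   # boolean mask of shape [1, 6]
x[y] = 0      # expected to raise the mask-shape error once patched
print(x)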

More information…

        # With Monkey Patching:  SUCCEEDS
        # Without Monkey Patching:  FAILS
        # Note w is of shape [6]
        # w = torch.tensor(np.array([True, True, True, True, True, True]), dtype=torch.bool)
        # x[w] = 0

        # With Monkey Patching:  FAILS
        # Without Monkey Patching:  SUCCEEDS
        # Note w is of shape [1,6]
        # w = torch.tensor(np.array([[True, False, True, False, True, False]]), dtype=torch.bool)
        # x[w] = 0

        # With Monkey Patching:  SUCCEEDS
        # Without Monkey Patching:  SUCCEEDS
        x[0, [True, False, True, False, True, False]] = 0

So wrapping __getitem__ is somehow modifying the shape of the tensor, or it is breaking the promotion of certain mask shapes.

Or possibly it changes which indexing overload is selected (Tensor vs. List).
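
For comparison, the "Without Monkey Patching" rows above line up with stock PyTorch's mask-shape rule, which can be checked in isolation (the exact error wording may vary across versions):

import torch

x = torch.zeros(1, 6)

# Mask of shape [1, 6] matches the tensor's shape, so assignment succeeds
good = torch.tensor([[True, False, True, False, True, False]])
x[good] = 1

# Mask of shape [6] does not match the leading dimension of [1, 6],
# so stock PyTorch raises an IndexError similar to the one quoted above
bad = torch.tensor([True, False, True, False, True, False])
try:
    x[bad] = 1
except IndexError as e:
    print(e)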