Below is a complete NN that learns to multiply by 10. It’s job is irrelevant as it is just a tight example to illustrate the problem.
At the beginning of my forward, I have these lines (gibberish for this example, but in my larger project are crucial)
y = x<8.0
x[y] = 0
This causes my monkey patching to crash, even though, in this example, I’m doing nothing other than returning the result. If I ignore _getitem, it functions. So it has something to do with the slicing.
If I don’t ignore getitem, my error is:
“The shape of the mask [6] at index 0 does not match the shape of the indexed tensor [1, 6] at index 0”
if I ignore getitem, then it “works”, but lines like
x[:,1:3] = 0
No longer function (the above line works fine when not ignoring getitem)
Note, my wrapper function should be identity:
def _add_wrapper(mod, fn_name):
func = getattr(mod, fn_name)
def call(*args, **kwargs):
# Call the original function
result = func(*args, **kwargs)
return result
setattr(mod, fn_name, call)
My goal is to graph out the execution graph of pytorch, so if there is an easier way to accomplish this, I’m ears. For classes, I can use register_forward_pre_hook and register_forward_hook and do what I need flawlessly. But since Pytorch allows intermixing of functional forwards, those are difficult to hook into and capture when they are executing, so I’m trying monkey patching with limited success (I have gotten it to work on some models but not others, depending on the operations used)
The below should train in pretty much any environment, as it is extremely simple, just uses the CPU, and trains to completion in a second or two.
To see the code run to completion either:
- comment the first two lines in forward
- add getitem to the ignore list
- don’t monkey patch
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import inspect as ins
def _isfunc(mod, f):
if not hasattr(mod, f):
return False
attr = getattr(mod, f)
# Ignore functions from this list
# if __setitem__ and __getitem__ are ignored, the crash goes away, but I also
# cannot capture slicing operations.
ignore = [
'__all__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__builtins__', '__cached__',
'__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__file__',
'__format__', '__getattribute__', '__hash__', '__index__', '__init__', '__init_subclass__',
'__iter__', '__len__', '__loader__', '__module__', '__name__', '__new__', '__nonzero__', '__package__',
'__path__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', # '__setitem__', '__getitem__',
'__setstate__', '__sizeof__', '__spec__', '__str__', '__subclasshook__', '__version__', '__weakref__'
]
# # Add functions to this list if they cause recursion
ignore += ['size', 'tolist', 'dim', 'is_storage', 'item', 'is_grad_enabled']
if f in ignore:
return False
return ins.ismethod(attr) or ins.isfunction(attr) or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr)
def _add_wrapper(mod, fn_name):
func = getattr(mod, fn_name)
def call(*args, **kwargs):
# Call the original function
result = func(*args, **kwargs)
return result
setattr(mod, fn_name, call)
def _patchClass(cls):
for f in dir(cls):
if _isfunc(cls, f):
_add_wrapper(cls, f)
# Monkey-patch classes in torch
def _patch_torch_classes():
for cls in [
torch,
torch.Tensor, # this is the class that is causing problems
torch.nn.functional,
torch.cuda,
torch.distributed
]:
_patchClass(cls)
# Monkey-patch forward functions in torch.nn libraries
def _patch_torch_nn_forward_functions():
for cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]:
if _isfunc(cls, 'forward'):
_add_wrapper(cls, 'forward')
# Simple Net that learns to multiply by 10
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.dense1 = nn.Linear(6, 6)
def forward(self, x):
# the following two lines will crash if monkey-patching is enabled.
y = x<8.0
x[y] = 0
# This line does not cause a crash if monkey-patching is enabled.
# x[:,1:3] = 0
x = x.flatten()
x = self.dense1(x)
return x
def main():
net = Net()
# monkey patch pytorch
_patch_torch_classes()
_patch_torch_nn_forward_functions()
# Loss
criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# 2x1x6 - input
pth_input0 = torch.tensor(np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), dtype=torch.float32)
pth_input1 = torch.tensor(np.array([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]), dtype=torch.float32)
pth_input = [pth_input0, pth_input1]
# 2x1x6 - labels
pth_ref0 = torch.tensor(np.array([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]]), dtype=torch.float32)
pth_ref1 = torch.tensor(np.array([[70.0, 80.0, 90.0, 100.0, 110.0, 120.0]]), dtype=torch.float32)
pth_ref = [pth_ref0, pth_ref1]
# Training Loop
for epoch in range(100): # loop over the dataset multiple times
for batch_index, data in enumerate(pth_input):
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(pth_input[batch_index])
loss = criterion(outputs, pth_ref[batch_index])
loss.backward()
optimizer.step()
print('[%d] loss: %.3f' % (epoch + 1, loss.item()))
print(outputs.detach().cpu().numpy())
print('Finished Training')
if __name__ == '__main__':
main()