Code:
```python
import torch
import torch.fx
from types import MethodType
from copy import deepcopy


class Student(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)
        self.gamma = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

    def set_score(self, score):
        self.score = score


s = Student()
# Take the plain function from the class so it can be rebound below.
func = getattr(s.__class__, "set_score")
fx_s = torch.fx.symbolic_trace(s)
# The traced GraphModule has no set_score, so bind it to fx_s by hand.
setattr(fx_s, "set_score", MethodType(func, fx_s))
fx_s_copy = deepcopy(fx_s)  # raises RecursionError
```
Bug: the final `deepcopy(fx_s)` call fails with `RecursionError: maximum recursion depth exceeded while calling a Python object`.
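One plausible explanation, offered as an assumption since it depends on how `GraphModule.__deepcopy__` is implemented in the PyTorch version in use: `copy.deepcopy` copies a bound method by deep-copying its `__self__`, and `fx_s.set_score.__self__` is `fx_s` itself. If the module's `__deepcopy__` deep-copies its `__dict__` without registering itself in, or forwarding, the `memo` dictionary, every nested copy of `fx_s` starts over and never terminates. The pure-Python sketch below reproduces that pattern without torch; the classes `NoMemoCopy` and `MemoCopy` and the helper `copy_raises_recursion` are hypothetical names used only for illustration.

```python
import copy
from types import MethodType


def set_score(self, score):
    # Same shape as Student.set_score in the reproducer.
    self.score = score


class NoMemoCopy:
    """Hypothetical stand-in for a __deepcopy__ that ignores `memo`."""

    def __deepcopy__(self, memo):
        new = type(self).__new__(type(self))
        # memo is neither updated nor forwarded: when the dict copy reaches a
        # bound method whose __self__ is `self`, deepcopy(self) starts over.
        new.__dict__ = copy.deepcopy(self.__dict__)
        return new


class MemoCopy:
    """Same idea, but registers the copy in `memo` before recursing."""

    def __deepcopy__(self, memo):
        new = type(self).__new__(type(self))
        memo[id(self)] = new  # later references to `self` resolve to `new`
        new.__dict__ = copy.deepcopy(self.__dict__, memo)
        return new


def copy_raises_recursion(obj):
    """Return True if copy.deepcopy(obj) hits the recursion limit."""
    try:
        copy.deepcopy(obj)
    except RecursionError:
        return True
    return False


bad = NoMemoCopy()
bad.set_score = MethodType(set_score, bad)  # bound method stored on the instance

good = MemoCopy()
good.set_score = MethodType(set_score, good)
good_copy = copy.deepcopy(good)
good_copy.set_score(0.5)

print(copy_raises_recursion(bad))                 # True
print(good_copy.set_score.__self__ is good_copy)  # True
print(good_copy.score, hasattr(good, "score"))    # 0.5 False
```

If this is indeed the cause, a workaround that avoids the cycle altogether is to call `deepcopy(fx_s)` before attaching the method, then bind `MethodType(func, ...)` to the original and to the copy separately, so no self-referencing bound method sits in `__dict__` while copying.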