I have a family of functions whose forward method follows this structure:
```python
import torch


class ParentFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx):
        output1 = ParentFunc.my_class_method1()
        output2 = ParentFunc.my_class_method2(output1)
        return output2

    @classmethod
    def my_class_method1(cls):
        return compute1()

    @classmethod
    def my_class_method2(cls, output1):
        return compute2(output1)

    @staticmethod
    def backward(ctx, grad_output):
        pass  # not important right now
```
With this structure, I can implement a general case that works for many of my child functions simply by inheriting forward() and the class methods, which is great. My hope was that, for edge cases, I would only have to override a few class methods and reuse the rest of the inherited code instead of copy-pasting the entire block.
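One detail of this design that is worth checking: a `@staticmethod` receives no class argument, so when `forward()` names `ParentFunc` explicitly, every subclass that inherits it still calls `ParentFunc`'s versions of the helpers. A `@classmethod` that calls helpers through `cls` does dispatch to the subclass. The plain-Python sketch below (no PyTorch) shows the difference; `Parent`, `Child`, `step`, `hardcoded` and `via_cls` are hypothetical names chosen for illustration.

```python
class Parent:
    @staticmethod
    def hardcoded(x):
        # No class argument is available here, so the class name is fixed:
        # subclasses that inherit this method still get Parent.step.
        return Parent.step(x)

    @classmethod
    def via_cls(cls, x):
        # cls is the class the method was looked up on,
        # so an overridden step() in a subclass is picked up.
        return cls.step(x)

    @classmethod
    def step(cls, x):
        return x + 1


class Child(Parent):
    @classmethod
    def step(cls, x):
        return x * 10


print(Child.hardcoded(2))  # 3  (Parent.step was used)
print(Child.via_cls(2))    # 20 (Child.step was used)
```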
Here is an edge-case example:

```python
class ChildFunc(ParentFunc):
    @staticmethod
    def forward(ctx):
        output1 = ChildFunc.my_class_method1()  # new definition
        output2 = ParentFunc.my_class_method2(output1)
        return output2

    @classmethod
    def my_class_method1(cls):
        return compute1_child()
```
When I run ChildFunc, I can't get it to call the overridden forward() or my_class_method1(). In VSCode, these functions are shown as residing in ParentFunc. The function inputs and outputs are the same, which seems to be a requirement for overriding in Python.
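Independently of what the editor's hover text shows, Python records which class each method was defined on, so it is possible to check at runtime whether `ChildFunc`'s overrides exist. A minimal sketch, using plain classes as stand-ins for the `torch.autograd.Function` subclasses (the method bodies are placeholders): `vars(cls)` lists only attributes defined directly on that class, and `__qualname__` names the defining class. As an aside, Python does not enforce matching signatures for overrides; defining a method with the same name in the subclass is what replaces the inherited one.

```python
class ParentFunc:
    @staticmethod
    def forward(ctx):
        return "parent forward"

    @classmethod
    def my_class_method1(cls):
        return "parent method1"


class ChildFunc(ParentFunc):
    @staticmethod
    def forward(ctx):
        return "child forward"

    @classmethod
    def my_class_method1(cls):
        return "child method1"


# True only if forward is defined directly on ChildFunc (not inherited).
print("forward" in vars(ChildFunc))                      # True
# __qualname__ reports the class where the function was defined.
print(ChildFunc.forward.__qualname__)                    # ChildFunc.forward
print(ChildFunc.my_class_method1.__func__.__qualname__)  # ChildFunc.my_class_method1
print(ChildFunc.forward(None))                           # child forward
```

If checks like these pass but the parent code still runs, it may be worth confirming that the call site uses `ChildFunc.apply(...)` rather than `ParentFunc.apply(...)`, and that the edited module is the one actually being imported.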
While looking for options, I found name mangling, where you rename forward() to _forward() or __forward(), but then PyTorch no longer calls it automatically through .apply() or a direct call. With _forward() or __forward(), VSCode does acknowledge that the new definition resides in ChildFunc.
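This behavior matches how Python treats the two spellings: a single leading underscore is only a naming convention, while a name with two leading underscores inside a class body is rewritten to `_ClassName__name`. So `__forward` defined in `ChildFunc` is stored as `_ChildFunc__forward`, a different attribute from `_ParentFunc__forward`, and neither is the attribute named `forward` that `.apply()` looks up. A small plain-Python illustration (class names reused from the post, method bodies are placeholders):

```python
class ParentFunc:
    def __forward(self):
        return "parent"


class ChildFunc(ParentFunc):
    def __forward(self):
        return "child"


# Double-underscore names are rewritten per class ("name mangling").
print(sorted(n for n in vars(ChildFunc) if "forward" in n))   # ['_ChildFunc__forward']
print(sorted(n for n in vars(ParentFunc) if "forward" in n))  # ['_ParentFunc__forward']
# No attribute is called plain "forward", so code looking it up by that name finds nothing.
print(hasattr(ChildFunc, "forward"))  # False
```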
Is there anything I can do to implement this with inheritance? I am not a Python or PyTorch expert, so I am hoping I am missing something.
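One way to keep most of the code inherited, sketched here under stated assumptions: move the body of `forward()` into a `@classmethod` (called `_forward_impl` here, a hypothetical name) that calls the helpers through `cls`, and have each subclass define only a one-line `@staticmethod` `forward()` that passes its own class in. To keep the sketch runnable without PyTorch, `StandInFunction` is a hypothetical stand-in whose `apply()` calls `cls.forward(ctx, *args)`, roughly the way `torch.autograd.Function.apply()` passes a context object as the first argument of `forward()`. In real code you would subclass `torch.autograd.Function` instead, drop the stand-in, and implement `backward()`. The arithmetic in the helpers is a placeholder for `compute1()`, `compute2()` and `compute1_child()`.

```python
class StandInFunction:
    """Hypothetical stand-in for torch.autograd.Function (assumption, not the real class)."""

    @classmethod
    def apply(cls, *args):
        ctx = object()  # placeholder for the context object PyTorch would create
        return cls.forward(ctx, *args)


class ParentFunc(StandInFunction):
    @staticmethod
    def forward(ctx, x):
        # The only line that has to name the class explicitly.
        return ParentFunc._forward_impl(ctx, x)

    @classmethod
    def _forward_impl(cls, ctx, x):
        # Shared logic: helpers are looked up on cls, so subclass overrides are used.
        output1 = cls.my_class_method1(x)
        return cls.my_class_method2(output1)

    @classmethod
    def my_class_method1(cls, x):
        return x + 1  # placeholder for compute1()

    @classmethod
    def my_class_method2(cls, output1):
        return output1 * 2  # placeholder for compute2()


class ChildFunc(ParentFunc):
    @staticmethod
    def forward(ctx, x):
        # One-line override: reuse the shared body, but dispatch on ChildFunc.
        return ChildFunc._forward_impl(ctx, x)

    @classmethod
    def my_class_method1(cls, x):
        return x - 1  # placeholder for compute1_child()


print(ParentFunc.apply(3))  # (3 + 1) * 2 = 8
print(ChildFunc.apply(3))   # (3 - 1) * 2 = 4
```

The one-line `forward()` still has to be repeated in each subclass, because a static method cannot tell which class it was reached through; everything else, including `my_class_method2`, is inherited unchanged, and `forward` itself stays a static method as PyTorch expects.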