How to replace a custom module with torch.fx?

```python
import torch


class MyOldMod(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print("IN MyOldMod")
        y = x * 1  # identity op, stands in for the real module
        return y


class MyNewMod(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print("IN MyNewMod")
        y = x + 2
        return y


class MyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1a = MyOldMod()  # the submodule I want to swap for MyNewMod
        self.bn1 = torch.nn.BatchNorm2d(1)

    def forward(self, x):
        print("In MyNet!")
        return self.bn1(self.conv1a(x))
```

This is just a trivial example. My goal is to replace MyOldMod with MyNewMod, but torch.fx doesn't seem to detect custom modules (MyOldMod here). What should I do?
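For context, a likely explanation: by default `torch.fx.symbolic_trace` only treats standard `torch.nn` modules as leaves, so it traces *through* `MyOldMod.forward` and inlines `x * 1` as a plain multiplication node. No node in the graph refers to `conv1a`, and the `print` calls run once during tracing instead of appearing in the graph. One approach worth trying is a custom `torch.fx.Tracer` that overrides `is_leaf_module` so `MyOldMod` stays a single `call_module` node, and then swapping the submodule registered under that name. This is a minimal sketch assuming a reasonably recent PyTorch (`nn.Module.get_submodule` needs 1.9+); `KeepMyOldModTracer` and `replace_old_mods` are hypothetical names, not part of torch.fx, and the modules are trimmed copies without the prints.

```python
import operator

import torch
import torch.fx


# Trimmed copies of the modules from the question (print calls removed).
class MyOldMod(torch.nn.Module):
    def forward(self, x):
        return x * 1


class MyNewMod(torch.nn.Module):
    def forward(self, x):
        return x + 2


class MyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1a = MyOldMod()
        self.bn1 = torch.nn.BatchNorm2d(1)

    def forward(self, x):
        return self.bn1(self.conv1a(x))


# Default tracing: MyOldMod is not a leaf, so its forward is inlined and shows
# up only as an operator.mul node; no node targets "conv1a".
default_gm = torch.fx.symbolic_trace(MyNet())
default_targets = [node.target for node in default_gm.graph.nodes]


class KeepMyOldModTracer(torch.fx.Tracer):
    # Hypothetical tracer: treat MyOldMod as a leaf so it stays a call_module node.
    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, MyOldMod):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def replace_old_mods(root):
    # Hypothetical helper: trace with the custom tracer, then swap every
    # MyOldMod submodule for a fresh MyNewMod under the same qualified name.
    graph = KeepMyOldModTracer().trace(root)
    gm = torch.fx.GraphModule(root, graph)
    for node in gm.graph.nodes:
        if node.op == "call_module" and isinstance(gm.get_submodule(node.target), MyOldMod):
            parent_name, _, attr = node.target.rpartition(".")
            # get_submodule("") returns gm itself, so top-level names work too.
            setattr(gm.get_submodule(parent_name), attr, MyNewMod())
    gm.recompile()
    return gm


net = MyNet().eval()  # eval mode so BatchNorm uses fixed running statistics
new_gm = replace_old_mods(net)
x = torch.randn(2, 1, 3, 3)
print(new_gm.code)  # generated forward still calls self.conv1a, now a MyNewMod
```

The original `net` is left untouched because only the attribute on the new `GraphModule` is reassigned. If you don't actually need the traced graph, simply assigning `net.conv1a = MyNewMod()` before tracing is the easier route; the tracer approach is mainly useful when you are already rewriting a graph and want to locate such modules in it.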

Can anyone help me solve this?