I think you should be able to use torch.fx, with its ability to manipulate the traced graph, as described here. In the example they replace add() calls with mul() calls, and I assume you can use the same or a similar approach to replace the ReLU modules.
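Something along these lines might work — a sketch, not a tested solution. The toy model and the choice of nn.LeakyReLU as the replacement are my assumptions, since the original question doesn't say what should go in place of ReLU. Note that symbolic_trace flattens the module hierarchy into a single graph, so ReLUs buried inside an nn.Sequential show up as ordinary call_module nodes:

```python
import torch
import torch.nn as nn
from torch import fx

# Toy model with a ReLU nested inside a Sequential (illustrative only).
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.block(x))

model = Net()
traced = fx.symbolic_trace(model)

# Walk every node in the flattened graph; for nodes that call a ReLU
# submodule, swap that submodule for a LeakyReLU (assumed replacement).
for node in traced.graph.nodes:
    if node.op == "call_module" and isinstance(
        traced.get_submodule(node.target), nn.ReLU
    ):
        traced.add_submodule(node.target, nn.LeakyReLU())

traced.recompile()
out = traced(torch.randn(2, 4))  # smoke test: forward still runs
```

Because the graph is flat, no recursion is needed here; add_submodule accepts dotted targets like "block.1", so the nested ReLU is replaced too.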
Thank you, but this solution doesn’t work as-is - maybe I need to make it recursive…
It finds the function calls at the top level, but those calls might sit inside a Sequential layer or some other nested submodule…
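If torch.fx turns out to be awkward here, the recursion you mention can also be done directly on the module tree, without tracing at all — a sketch under the same assumption that ReLU should become LeakyReLU (the original post doesn't name the replacement):

```python
import torch.nn as nn

def replace_relu(module: nn.Module, new_act=nn.LeakyReLU) -> None:
    """Recursively swap every nn.ReLU child for new_act, however deeply nested."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            # setattr works for Sequential children too: their names are "0", "1", ...
            setattr(module, name, new_act())
        else:
            replace_relu(child, new_act)

# Example with ReLUs at the top level and inside a nested Sequential.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.ReLU(), nn.Linear(4, 2)),
    nn.ReLU(),
)
replace_relu(model)
```

This only catches nn.ReLU *modules*; functional calls like F.relu(x) inside a forward() are invisible to the module tree, which is exactly the case where graph-level tools like torch.fx are needed.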