Determine inter-layer channel dependency

I want to implement a function that changes the number of output channels of a layer on the fly. To keep the model valid after the modification, I need to find all layers that depend on that layer's output (and there can be trivial layers/functions such as ReLU between them). Is there any existing infrastructure or function in PyTorch that can help me with this?
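One possible starting point is `torch.fx`: `torch.fx.symbolic_trace` turns a module into a graph whose nodes expose `op`, `target` and `users`, so the consumers of a layer can be found by walking `users` forward. Below is a minimal sketch under several assumptions. The helper `find_dependent_layers`, the `Net` example model, and the two lists of "channel-preserving" ops are hypothetical and illustrative, not an existing PyTorch API. The walk looks through ops assumed to keep the channel count (ReLU, Dropout, ...), reports BatchNorm as a dependent but keeps walking past it (its `num_features` must change too), and simply stops at anything else (`add`, `cat`, `call_method`, ...). It also assumes the model is symbolically traceable (no data-dependent control flow).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx

# Hypothetical, incomplete lists: ops assumed to pass channels through unchanged.
PASSTHROUGH_MODULES = (nn.ReLU, nn.ReLU6, nn.Dropout, nn.Identity)
PASSTHROUGH_FUNCTIONS = {torch.relu, F.relu}
# Modules whose parameters are tied to the channel count but that also pass it through.
RESIZE_AND_PASS_MODULES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def find_dependent_layers(model: nn.Module, layer_name: str) -> list[str]:
    """Return qualified names of modules that consume `layer_name`'s output,
    looking through channel-preserving ops such as ReLU."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    start = next(n for n in gm.graph.nodes
                 if n.op == "call_module" and n.target == layer_name)

    dependents: list[str] = []
    stack = list(start.users)  # Node.users maps each consumer node to None
    seen = set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node.op == "call_module":
            mod = modules[node.target]
            if isinstance(mod, PASSTHROUGH_MODULES):
                stack.extend(node.users)       # look through, e.g. nn.ReLU
            elif isinstance(mod, RESIZE_AND_PASS_MODULES):
                dependents.append(node.target)  # must be resized...
                stack.extend(node.users)       # ...and channels flow onward
            else:
                dependents.append(node.target)  # e.g. Conv2d / Linear input side
        elif node.op == "call_function" and node.target in PASSTHROUGH_FUNCTIONS:
            stack.extend(node.users)           # look through F.relu / torch.relu
        # Any other node (add, cat, call_method, output) ends the walk here.
    return dependents


class Net(nn.Module):
    """Hypothetical example model."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, 3)

    def forward(self, x):
        return self.conv2(self.act(self.bn1(self.conv1(x))))


if __name__ == "__main__":
    print(find_dependent_layers(Net(), "conv1"))  # ['bn1', 'conv2']
```

Residual connections (`add`) and concatenations (`cat`) are where real channel dependencies get tricky, since all inputs to an `add` must agree on channel count; a fuller version would need explicit rules for those ops rather than stopping there.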