Is there a good way to clear the graph connected to a tensor after it has gone through half of the layers?

I have a situation where I sometimes want an input to pass gradients only to the top layers, and sometimes to all of the layers. Currently I do it like this:

```python
def forward(self, x, top=True):
    x = self.first_layers(x)
    if top:
        # Cut the graph here so no gradient flows back into first_layers.
        x = x.detach()

    return self.last_layers(x)
```
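To see what the `x.detach()` call does on its own, here is a minimal tensor-level sketch. The tensors `w`, `h` and `v` are arbitrary illustrative values, not part of the original model; `w` plays the role of a layer's parameters.

```python
import torch

w = torch.randn(3, requires_grad=True)  # stands in for a layer's parameters
h = w * 2                               # recorded in the graph (h.grad_fn is set)
h_cut = h.detach()                      # same values and storage, but cut from the graph

print(h.requires_grad, h.grad_fn is not None)      # True True
print(h_cut.requires_grad, h_cut.grad_fn is None)  # False True
print(h_cut.data_ptr() == h.data_ptr())            # True: memory is shared, not copied

# A loss built on h_cut can still be backpropagated into other leaves,
# but nothing reaches w, because the path through h was cut.
v = torch.randn(3, requires_grad=True)
loss = (h_cut * v).sum()
loss.backward()
print(w.grad is None, v.grad is not None)          # True True
```

Because the storage is shared, modifying `h_cut` in place also changes `h`; call `.clone()` after `detach()` if an independent copy is needed.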

I am wondering whether this actually does what I intend, and whether there is a better way to do it. Thanks!

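Yes, breaking the graph is exactly what `detach()` is for: it returns a tensor with the same data that is no longer attached to the autograd graph, so no gradients flow back into `self.first_layers`. So this is the proper way to do it.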
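One way to confirm the behaviour is to run a backward pass and check which layers received a `.grad`. The sketch below is hypothetical: `TwoStageNet`, its `nn.Linear` stand-ins for `first_layers`/`last_layers`, the `use_no_grad` flag and the `grads_after_backward` helper are illustrative, not from the original post. The `use_no_grad` branch shows a variant that wraps the first stage in `torch.no_grad()` instead of detaching afterwards.

```python
import torch
import torch.nn as nn


class TwoStageNet(nn.Module):
    # Hypothetical stand-ins for first_layers / last_layers from the question.
    def __init__(self):
        super().__init__()
        self.first_layers = nn.Linear(8, 16)
        self.last_layers = nn.Linear(16, 4)

    def forward(self, x, top=True, use_no_grad=False):
        if top and use_no_grad:
            # Variant: don't record the first stage in the graph at all.
            with torch.no_grad():
                x = self.first_layers(x)
        else:
            x = self.first_layers(x)
            if top:
                # Approach from the question: cut the graph after the first stage.
                x = x.detach()
        return self.last_layers(x)


def grads_after_backward(model, top, use_no_grad=False):
    """Run one backward pass and report which stages received gradients."""
    model.zero_grad(set_to_none=True)  # reset grads to None so the check is meaningful
    x = torch.randn(2, 8)
    model(x, top=top, use_no_grad=use_no_grad).sum().backward()
    return {
        "first": model.first_layers.weight.grad is not None,
        "last": model.last_layers.weight.grad is not None,
    }


torch.manual_seed(0)
model = TwoStageNet()
print(grads_after_backward(model, top=True))                    # {'first': False, 'last': True}
print(grads_after_backward(model, top=False))                   # {'first': True, 'last': True}
print(grads_after_backward(model, top=True, use_no_grad=True))  # {'first': False, 'last': True}
```

Both `top=True` variants give `last_layers` the same gradients. The `torch.no_grad()` form simply skips recording the first stage's operations instead of recording them and then discarding them, which can save a little time and memory when the first stage's gradients are never needed for that input.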