I have a situation where I want an input to sometimes pass gradients only to the top layers, and sometimes to all of the layers. The way I am currently accomplishing this is:
    def forward(self, x, top=True):
        x = self.first_layers(x)
        if top:
            x = x.detach()
            x.requires_grad_(True)
        return self.last_layers(x)
I am wondering whether this actually accomplishes what I intend, and whether there is a better way to do it. Thanks!
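For reference, here is a minimal sketch of how the behaviour could be checked. It is not my real model: the `Net` class, the `nn.Linear` stand-ins for `first_layers` and `last_layers`, the layer sizes, and the `grads_present` helper are all made up for illustration. It runs one backward pass for each value of `top` and records which parameters end up with a gradient.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Toy stand-in for the real model; layer types and sizes are hypothetical."""

    def __init__(self):
        super().__init__()
        self.first_layers = nn.Linear(4, 8)
        self.last_layers = nn.Linear(8, 2)

    def forward(self, x, top=True):
        x = self.first_layers(x)
        if top:
            # Cut the graph here so nothing flows back into first_layers.
            x = x.detach()
            x.requires_grad_(True)
        return self.last_layers(x)


def grads_present(model, top):
    """Run one forward/backward pass and report which parameters got a gradient."""
    model.zero_grad(set_to_none=True)  # reset so stale gradients do not mask the result
    out = model(torch.randn(3, 4), top=top)
    out.sum().backward()
    return {name: p.grad is not None for name, p in model.named_parameters()}


model = Net()
top_only = grads_present(model, top=True)
all_layers = grads_present(model, top=False)
print("top=True :", top_only)
print("top=False:", all_layers)
```

If the detach works as intended, the `first_layers` entries should have no gradient when `top=True` and should have one when `top=False`, while the `last_layers` entries should have a gradient in both cases.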