Right now, I’m only trying to train a discriminator to always classify as ‘1’. Its input comes from the last convolutional layer of resnet (a 2056x1x1 slice through the channels).
I pass the output of the discriminator through a BCELoss criterion, and call .backward() on that loss.
However, if I call .detach() on the input to the discriminator, or freeze the layers of the ResNet before calling loss.backward(), the discriminator doesn’t learn a thing. If I don’t detach/freeze, the network learns very quickly, but of course having the gradient update the encoder while I’m training the discriminator is exactly what I don’t want.
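For reference, my understanding is that a detach alone shouldn’t block the discriminator from learning. Here’s a minimal sketch of what I’d expect, with tiny stand-in modules in place of my real encoder and discriminator (the shapes and layers are just placeholders, not my actual code):

```python
import torch
import torch.nn as nn

# Stand-ins for the real models: a tiny "encoder" producing 2056-dim
# features (like the ResNet channel slice) and a one-layer discriminator.
encoder = nn.Linear(8, 2056)
disc = nn.Sequential(nn.Linear(2056, 1), nn.Sigmoid())
criterion = nn.BCELoss()

x = torch.randn(4, 8)
features = encoder(x).detach()           # cut the graph at the encoder output
out = disc(features)
loss = criterion(out, torch.ones(4, 1))  # train toward the "real" label
loss.backward()

# The discriminator's parameters still receive gradients;
# the encoder's parameters do not.
assert all(p.grad is not None for p in disc.parameters())
assert all(p.grad is None for p in encoder.parameters())
```

In this toy version the asserts pass, so the detach only stops gradients upstream of it, which is why the behavior I’m seeing confuses me.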
I’ve checked that requires_grad=True for the parameters of the discriminator, and have tried it with several different discriminator implementations.
I’ve run out of ideas.
Probably worth mentioning the one thing I’m doing that the examples I’ve looked at aren’t: I connect the ResNet encoder to the discriminator inside the forward() function of a parent nn.Module class that ties everything together.
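To make that concrete, here’s a hedged sketch of roughly how my parent module is wired (module names, sizes, and the detach_features flag are illustrative, not my actual implementation). In this toy version, training through the detach does drive the loss down:

```python
import torch
import torch.nn as nn

class Parent(nn.Module):
    # Hypothetical wrapper mirroring my setup: encoder -> discriminator,
    # with the graph cut inside forward() so only the discriminator trains.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)  # stand-in for the ResNet encoder
        self.disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

    def forward(self, x, detach_features=True):
        f = self.encoder(x)
        if detach_features:
            f = f.detach()               # encoder receives no gradient
        return self.disc(f)

model = Parent()
# Optimizer over discriminator parameters only.
opt = torch.optim.Adam(model.disc.parameters(), lr=1e-2)
criterion = nn.BCELoss()
x = torch.randn(16, 8)
target = torch.ones(16, 1)               # always classify as '1'

losses = []
for _ in range(50):
    opt.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    opt.step()
    losses.append(loss.item())

# The discriminator learns even though the features are detached.
assert losses[-1] < losses[0]
```

If this toy version works but my real one doesn’t, maybe the issue is in how the parent module’s forward() differs from this, which is why I wanted to flag it.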