Hi,
I am not sure if I understand your question.
It seems that you have two networks, subnetwork1 and subnetwork2, where subnetwork2 is pre-trained, and what you want looks like this:
data -> subnetwork1 -> output1 -> subnetwork2 -> loss
and the params of the pre-trained subnetwork2 should not be updated.
I think you can set requires_grad=False on the params of subnetwork2; this thread may help you.
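A minimal sketch of that idea, assuming both subnetworks are ordinary nn.Module instances (the nn.Linear layers, sizes, MSE loss and SGD optimizer below are hypothetical stand-ins, not taken from your code). Freezing subnetwork2 stops its own weights from receiving gradients, but gradients still flow through it back into subnetwork1, so only subnetwork1 is trained:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks from the question.
subnetwork1 = nn.Linear(4, 8)
subnetwork2 = nn.Linear(8, 1)  # assumed to be the pre-trained one

# Freeze subnetwork2: its parameters will not get .grad and will not be updated.
for p in subnetwork2.parameters():
    p.requires_grad = False

# Give the optimizer only the parameters that should be trained.
optimizer = torch.optim.SGD(subnetwork1.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

data = torch.randn(16, 4)
target = torch.randn(16, 1)

# Snapshots so we can check what changed after one step.
w1_before = subnetwork1.weight.detach().clone()
w2_before = subnetwork2.weight.detach().clone()

# data -> subnetwork1 -> output1 -> subnetwork2 -> loss
optimizer.zero_grad()
output1 = subnetwork1(data)
loss = loss_fn(subnetwork2(output1), target)
loss.backward()  # backprops *through* subnetwork2 into subnetwork1
optimizer.step()
```

Note that requires_grad=False only affects gradients; if subnetwork2 contains dropout or batch norm layers, you may also want subnetwork2.eval() so those layers behave as they did after pre-training.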