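import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.network1 = ...  # first sub-network (definition omitted)
        self.network2 = ...  # second sub-network, applied to network1's output

    def forward(self, x):
        x = self.network1(x)
        y = self.network2(x)
        return x, y

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = ...
    ...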
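To make the layout concrete, here is a runnable toy version. The layer types and sizes of network1 and network2 are assumptions for illustration only, since the question does not specify them.

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    """Two sub-networks in sequence; forward returns both outputs."""

    def __init__(self, in_dim=8, hidden=16, out_dim=4):
        super().__init__()
        # Hypothetical layers and sizes; the real sub-networks are not shown.
        self.network1 = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(hidden, out_dim), nn.ReLU())

    def forward(self, x):
        x = self.network1(x)   # features from the first sub-network
        y = self.network2(x)   # features from the second, built on network1's output
        return x, y

net = Network()
x, y = net(torch.randn(2, 8))
print(x.shape, y.shape)  # torch.Size([2, 16]) torch.Size([2, 4])

# Each sub-network is a registered child module, so its parameters
# can be addressed separately, e.g. net.network2.parameters().
n2_params = sum(p.numel() for p in net.network2.parameters())
```

Because network1 and network2 are attributes holding nn.Module instances, they can later be frozen or excluded from an optimizer independently.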
What I want to do is:
Step 1) While epoch < 10, train the classifier together with the whole Network().
Step 2) Once epoch >= 10, train the classifier with only network1 in Network(), so that network2 is no longer updated.
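One way to express this two-step schedule is to choose, per phase, which parameters the optimizer is allowed to update, rebuilding the optimizer once at the switch. This is a minimal sketch: the module shapes, the assumption that the classifier consumes network2's output, the dummy data, and the names `SWITCH_EPOCH` and `make_optimizer` are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's modules (sizes are made up).
network1 = nn.Linear(8, 16)
network2 = nn.Linear(16, 4)
classifier = nn.Linear(4, 3)  # assumed to consume network2's output

SWITCH_EPOCH = 10  # first epoch of step 2
loss_fn = nn.CrossEntropyLoss()

def make_optimizer(epoch):
    """Step 1 optimizes all three modules; step 2 leaves network2 out."""
    params = list(network1.parameters()) + list(classifier.parameters())
    if epoch < SWITCH_EPOCH:
        params += list(network2.parameters())
    return torch.optim.SGD(params, lr=0.1)

# Dummy data, purely for illustration.
inputs = torch.randn(32, 8)
targets = torch.randint(0, 3, (32,))

optimizer = make_optimizer(0)
for epoch in range(15):
    if epoch == SWITCH_EPOCH:
        optimizer = make_optimizer(epoch)  # rebuild once, at the phase change
        n1_snapshot = network1.weight.detach().clone()
        n2_snapshot = network2.weight.detach().clone()
    optimizer.zero_grad()
    loss = loss_fn(classifier(network2(network1(inputs))), targets)
    loss.backward()
    optimizer.step()

# network2 is absent from the step-2 optimizer, so its weights stay unchanged.
print(torch.equal(network2.weight, n2_snapshot))  # True
```

With this approach network2 still computes and accumulates gradients (wasted work), any Dropout or BatchNorm inside it stays in training mode, and rebuilding the optimizer discards state such as momentum.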
How can I freeze network2 (i.e. stop its weights from being updated) in step 2?
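A common PyTorch approach is to turn off `requires_grad` on the sub-network's parameters and, if it contains Dropout or BatchNorm, switch it to eval mode. The sketch below assumes hypothetical layer sizes, a classifier that consumes `y`, dummy data, and a helper named `freeze` that is not part of PyTorch. Gradients still flow *through* the frozen network2 into network1, so network1 keeps training.

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical sub-networks; Dropout shows why eval() matters.
        self.network1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(16, 4), nn.Dropout(0.5))

    def forward(self, x):
        x = self.network1(x)
        y = self.network2(x)
        return x, y

def freeze(module: nn.Module) -> None:
    """Hypothetical helper: stop updates to `module` and disable its Dropout/BatchNorm updates."""
    for p in module.parameters():
        p.requires_grad_(False)  # no .grad is produced, so optimizers skip these
    module.eval()                # Dropout off, BatchNorm uses running stats

torch.manual_seed(0)
net = Network()
classifier = nn.Linear(4, 3)  # hypothetical classifier on top of y
optimizer = torch.optim.SGD(
    list(net.parameters()) + list(classifier.parameters()), lr=0.1
)

epoch = 10
net.train()
classifier.train()
if epoch >= 10:
    freeze(net.network2)  # apply after .train(), which would re-enable Dropout

before = net.network2[0].weight.detach().clone()
n1_before = net.network1[0].weight.detach().clone()

_, y = net(torch.randn(32, 8))
loss = nn.functional.cross_entropy(classifier(y), torch.randint(0, 3, (32,)))
optimizer.zero_grad()
loss.backward()   # gradients reach network1 through the frozen network2
optimizer.step()

print(torch.equal(net.network2[0].weight, before))  # True
```

If `net.train()` is called at the start of every epoch, call `freeze` again afterwards. Passing only trainable parameters to the optimizer, e.g. `[p for p in net.parameters() if p.requires_grad]`, is an optional extra that keeps the optimizer state smaller.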