I have a Class like this, I want to freeze Net1 and Net3 during my training.
Would it be possible?
class MyNetwork(nn.Module):
def __init__(self, a, b, c, d):
super(MyNetwork, self).__init__()
self.a = a
self.b = b
self.c = c
self.d = d
self.Net1 = nn.Linear(a, b)
self.Net2 = nn.Linear(b, c)
self.Net3 = nn.Linear(c, d)
def func1(self, x):
h_x = self.Net1(x)
h_x = self.Net2(h_x)
return h_x
def func2(self, x):
h_x2 = self.Net3(x)
return h_x2
def forward(self, x):
y = self.func1(x)
z = self.func2(y)
return z
I was thinking maybe I can do something like this, but I feel like this is wrong,
class MyNetwork(nn.Module):
def __init__(self, a, b, c, d):
super(MyNetwork, self).__init__()
self.a = a
self.b = b
self.c = c
self.d = d
self.Net1 = nn.Linear(a, b)
self.Net2 = nn.Linear(b, c)
self.Net3 = nn.Linear(c, d)
def func1(self, x):
h_x = self.Net1(x).eval()
h_x = self.Net2(h_x)
return h_x
def func2(self, x):
h_x2 = self.Net3(x).eval()
return h_x2
def forward(self, x):
y = self.func1(x)
z = self.func2(y)
return z