Hi,
I am having trouble setting requires_grad to True or False for only some specific sub-networks inside my network.
class MyNetwork(torch.nn.Module):
    …
    self.layer1 = Layer1()
    self.layer2 = Layer2()
    # Some other networks exist.
    …

class Layer1(torch.nn.Module):
    …
    self.net_a = torch.nn.Linear(xx)
    # Some other networks exist.
    …

class Layer2(torch.nn.Module):
    self.net_b = torch.nn.Linear(yy)
    # Some other networks exist.
    …
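For reference, a minimal runnable version of this structure might look like the sketch below. The layer sizes (4 and 2) and the extra `other` layers are hypothetical stand-ins for the elided `xx`, `yy`, and "other networks". It shows that each submodule is registered under a dotted name that mirrors its attribute path.

```python
import torch
import torch.nn as nn


class Layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_a = nn.Linear(4, 4)  # hypothetical sizes (original: xx)
        self.other = nn.Linear(4, 4)  # hypothetical stand-in for "other networks"

    def forward(self, x):
        return self.other(self.net_a(x))


class Layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = nn.Linear(4, 2)  # hypothetical sizes (original: yy)
        self.other = nn.Linear(2, 2)  # hypothetical stand-in for "other networks"

    def forward(self, x):
        return self.other(self.net_b(x))


class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()

    def forward(self, x):
        return self.layer2(self.layer1(x))


my_network = MyNetwork()

# Assigning an nn.Module as an attribute registers it as a submodule,
# named by its attribute path (e.g. "layer1.net_a").
for name, _ in my_network.named_modules():
    print(repr(name))
```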
Then I want to set requires_grad = False for only
my_network.layer1.net_a
and
my_network.layer2.net_b.
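One possible approach, sketched under the same hypothetical layer sizes as above: calling .parameters() on a submodule yields only that submodule's parameters, so the usual loop can be restricted to the two target layers by reaching them through plain attribute access.

```python
import torch.nn as nn


# Hypothetical reconstruction of the structure in the question (sizes are made up).
class Layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_a = nn.Linear(4, 4)
        self.other = nn.Linear(4, 4)


class Layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = nn.Linear(4, 2)
        self.other = nn.Linear(2, 2)


class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()


my_network = MyNetwork()

# Only the weight and bias of these two Linear layers are touched;
# every other parameter keeps requires_grad=True.
for sub in (my_network.layer1.net_a, my_network.layer2.net_b):
    for param in sub.parameters():
        param.requires_grad = False

for name, param in my_network.named_parameters():
    print(name, param.requires_grad)
```

`my_network.layer1.net_a.requires_grad_(False)` should have the same effect, since `nn.Module.requires_grad_` sets the flag on every parameter of that module.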
my_network.parameters() returns an iterator over all parameters, with which I can set requires_grad for the whole network as follows:
for param in my_network.parameters():
    param.requires_grad = False  # or True
However, I cannot figure out how to go beyond my_network.parameters() and target only particular sub-networks by name when setting requires_grad.
I looked into my_network.named_modules(), but it just iterates over all submodules under my_network with their names, and I don't see how that links to the requires_grad setting.
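A possible link, sketched with the same hypothetical structure: named_modules() yields (name, module) pairs rather than bare names, so the module object is available next to its name and can be frozen directly. The `targets` set is a hypothetical helper holding the dotted names to freeze.

```python
import torch.nn as nn


# Hypothetical reconstruction of the structure in the question (sizes are made up).
class Layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_a = nn.Linear(4, 4)
        self.other = nn.Linear(4, 4)


class Layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = nn.Linear(4, 2)
        self.other = nn.Linear(2, 2)


class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Layer1()
        self.layer2 = Layer2()


my_network = MyNetwork()

# Hypothetical list of submodule names to freeze; they must match the
# dotted names that named_modules() yields.
targets = {"layer1.net_a", "layer2.net_b"}

for name, module in my_network.named_modules():
    if name in targets:
        # Sets requires_grad=False on every parameter of this module.
        module.requires_grad_(False)

for name, param in my_network.named_parameters():
    print(name, param.requires_grad)
```

If matching on parameter names is more convenient, named_parameters() yields names such as `layer1.net_a.weight`, so a prefix check like `name.startswith("layer1.net_a.")` inside a named_parameters() loop would work as well.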
Thank you in advance for your advice.