```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.mask1 = Parameter(torch.ones(5))
        self.layer2 = nn.Linear(5, 5)
        self.mask2 = Parameter(torch.ones(5))
        self.layer3 = nn.Linear(5, 1)
        self.mask3 = Parameter(torch.ones(1))

    def forward(self, x):
        # Each mask is multiplied elementwise onto its layer's output.
        x = self.layer1(x)
        x = torch.mul(self.mask1, x)
        x = F.leaky_relu(x)
        x = self.layer2(x)
        x = torch.mul(self.mask2, x)
        x = F.leaky_relu(x)
        x = self.layer3(x)
        x = torch.mul(self.mask3, x)
        x = torch.tanh(x)
        return x


def main():
    n = Net()
    print(n)
    out = n(torch.randn(5, 10))  # batch of 5 -> output shape (5, 1)
    print(out)
    # Target shaped (5, 1) to match the output; the original (1, 5) target
    # silently broadcast the subtraction into a (5, 5) matrix.
    # (Note: this is the mean of raw differences, not MSE.)
    loss = torch.mean(out - torch.randn(5, 1))
    opt_n = torch.optim.SGD(n.parameters(), lr=0.0001)
    opt_n.zero_grad()
    loss.backward()
    opt_n.step()


if __name__ == "__main__":
    main()
```
Here is my network structure; the masks are meant to be binary. The trainable masks are supposed to learn to block some neuron connections as training progresses.
It seems that after `opt_n.step()` the masks are not updated by their gradients. Why would this happen?
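For reference, this is the kind of check I would use to see whether the masks receive gradients and how far they move per step (a minimal sketch reusing the `Net` class above; the check itself is my addition, not part of the training code):

```python
import torch

n = Net()
out = n(torch.randn(5, 10))
loss = torch.mean(out - torch.randn(5, 1))

opt_n = torch.optim.SGD(n.parameters(), lr=0.0001)
opt_n.zero_grad()
loss.backward()

before = n.mask1.detach().clone()
print("mask1 grad:", n.mask1.grad)  # non-None means the mask does receive a gradient
opt_n.step()
# SGD moves each parameter by -lr * grad, so at lr=1e-4 the change
# is tiny and easy to miss when printing the mask values.
print("mask1 delta:", n.mask1.detach() - before)
```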
And if I want to train the masks and layer1/2/3 separately, how should I exclude the layer parameters from the mask optimizer (and vice versa)? Or is there a better implementation, such as the sketch below?
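What I have in mind is roughly the following (a sketch, not a tested solution; the learning rates are placeholders): split the parameters by their registered names and give each group its own optimizer, stepping only one group per phase.

```python
import torch

n = Net()  # the Net class defined above

# nn.Module registers these parameters as "mask1", "mask2", "mask3" and
# "layer1.weight", "layer1.bias", etc., so we can split by name prefix.
mask_params = [p for name, p in n.named_parameters() if name.startswith("mask")]
layer_params = [p for name, p in n.named_parameters() if name.startswith("layer")]

opt_masks = torch.optim.SGD(mask_params, lr=0.01)
opt_layers = torch.optim.SGD(layer_params, lr=0.0001)

# Phase 1: update only the layer weights.
opt_layers.zero_grad()
loss = torch.mean(n(torch.randn(5, 10)) - torch.randn(5, 1))
loss.backward()    # fills grads for all parameters...
opt_layers.step()  # ...but only the layer group is applied

# Phase 2: update only the masks.
opt_masks.zero_grad()  # clears the stale mask grads from phase 1
loss = torch.mean(n(torch.randn(5, 10)) - torch.randn(5, 1))
loss.backward()
opt_masks.step()
```

Alternatively, the group that should stay fixed during a phase could be frozen with `p.requires_grad_(False)` and re-enabled afterwards, which also avoids computing its gradients.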