I'm trying to implement a hypernetwork. The idea is that a hypernetwork A produces the weights of a network B. We use B to perform classification and backpropagate through B to update the weights of A. The issue with a vanilla PyTorch implementation is that B's weights live as nn.Parameter objects inside an nn.Module, and parameters are leaf nodes: autograd stops at leaf nodes, so gradients never propagate back to A.
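To illustrate what I mean by the leaf-node problem, here is a toy sketch (the module names are my own, not from any library): copying a hypernetwork's output into a module's Parameter severs the autograd link, so the hypernetwork's weights get no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

A = nn.Linear(1, 4)               # hypernetwork: emits B's weight
B = nn.Linear(4, 1, bias=False)   # target network; B.weight is a leaf

gen = A(torch.rand(1)).view(1, 4) # weight generated by A, in the graph
with torch.no_grad():
    B.weight.copy_(gen)           # like a .data write: severs the graph

out = B(torch.rand(2, 4))
out.sum().backward()

print(B.weight.grad)              # populated: backprop stops at the leaf
print(A.weight.grad)              # None: no path back to A
```

The gradient lands on `B.weight` (the leaf) and goes no further, which is exactly the behavior I want to work around.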
It seems to me that make_functional() in functorch is a workaround: it returns a functional version of the network together with its parameters. See the following code:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.nc = 8
        self.conv1 = nn.Conv2d(1, self.nc, (3, 3), bias=False)
        self.conv2 = nn.Conv2d(self.nc, self.nc, (3, 3), bias=False)
        self.conv3 = nn.Conv2d(self.nc, self.nc, (3, 3), bias=False)
        self.fc = nn.Linear(484 * self.nc, 10, bias=False)
        nn.init.normal_(self.conv1.weight)
        nn.init.normal_(self.conv2.weight)
        nn.init.normal_(self.conv3.weight)
        nn.init.normal_(self.fc.weight)

    def forward(self, x):
        x = self.conv1(x) / torch.sqrt(torch.tensor(1 * 3 * 3))
        x = x.relu()
        x = self.conv2(x) / torch.sqrt(torch.tensor(self.nc * 3 * 3))
        x = x.relu()
        x = self.conv3(x) / torch.sqrt(torch.tensor(self.nc * 3 * 3))
        x = x.flatten(1)
        x = self.fc(x) / torch.sqrt(torch.tensor(self.nc * 484))
        return x

net = CNN().to(device)
fnet, params = make_functional(net)

x = torch.rand(1, 1, 28, 28).cuda()
A = nn.Linear(1, 72)
# overwrite the first conv's weight with the hypernetwork's output;
# params is a tuple of tensors, so index into it
params[0].data = A(torch.rand(1)).view(8, 1, 3, 3).cuda()

o = fnet(params, x)
o.sum().backward()
print(params[0].grad)
print(A.weight.grad)
However, A.weight.grad is still None. Is this expected?
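For context, here is a minimal sketch of the behavior I'm after, written without functorch (just torch.nn.functional, with names of my own choosing): if the generated weight tensor is passed directly into a functional forward, instead of being copied into a Parameter via .data, the gradient does reach A.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

A = nn.Linear(1, 3 * 4)           # hypernetwork: emits a 3x4 weight
z = torch.rand(1)                 # input to the hypernetwork
w_b = A(z).view(3, 4)             # generated weight, NOT detached

x = torch.rand(2, 4)              # batch for network B
out = F.linear(x, w_b)            # functional forward of B
out.sum().backward()

print(A.weight.grad is not None)  # True: gradient reaches A
```

So the question is whether the `.data` assignment in my snippet above is what breaks the chain, and what the intended functorch idiom is.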