Hello all, I created a simple network where a convolutional layer's weight matrix is replaced by the output of a custom function.
I came up with this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


class snet(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 2, 1, 0)
        shape = self.conv1.weight.shape
        self.var1 = nn.Parameter(torch.ones(shape))
        self.var2 = nn.Parameter(torch.ones(shape))
        self.conv2 = nn.Conv2d(6, 6, 5, 1, 0)
        self.fc = nn.Linear(6 * 11 * 11, num_classes)

    def some_method(self):
        """Suppose this is a custom method, tasked with producing values
        for each entry in the weight matrix of a convolutional layer.
        For simplicity, addition is used here."""
        return self.var1 + self.var2

    def forward(self, input):
        self.conv1.weight = nn.Parameter(self.some_method())
        output = self.conv1(input)
        # Using the functional API with the weight matrix directly is no different:
        # output = F.conv2d(input, self.some_method())
        output = self.conv2(output)
        output = output.view(input.size(0), -1)
        output = self.fc(output)
        return output


n = snet(num_classes=3)
fake_dataset = torchvision.datasets.FakeData(100, image_size=(3, 16, 16),
                                             num_classes=3,
                                             transform=transforms.ToTensor())
fake_dataloader = torch.utils.data.DataLoader(fake_dataset, batch_size=20)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(n.parameters(), lr=0.01)

for imgs, labels in fake_dataloader:
    p = n(imgs)
    loss = criterion(p, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
```
Apparently this is wrong, as nothing happens: the parameters are added to the module and show up in the parameters list, yet their gradient is always zero!
I also noticed that the grad_fn property for both variables/parameters is None, whereas it should have been the addition, right?
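To illustrate, here is a minimal sketch of the same observation outside the module, with two standalone parameters standing in for self.var1 and self.var2:

```python
import torch
import torch.nn as nn

# Stand-ins for self.var1 and self.var2
var1 = nn.Parameter(torch.ones(3))
var2 = nn.Parameter(torch.ones(3))

result = var1 + var2
print(result.grad_fn)   # the addition is recorded here (AddBackward0)

# Wrapping the result in a fresh nn.Parameter, as forward() does:
weight = nn.Parameter(result)
print(weight.grad_fn)   # None -- the new parameter carries no history
```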
Based on the autograd tutorial, when one variable in an operation has requires_grad = True, the output will also have requires_grad = True, and thus the gradient should flow back to those variables with requires_grad set to True!
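That propagation is easy to see in isolation, with a plain tensor in place of a network:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.ones(3)            # requires_grad defaults to False
z = (x + y).sum()
print(z.requires_grad)       # True -- propagated from x
z.backward()
print(x.grad)                # gradient flowed back to x
print(y.grad)                # None -- y did not require grad
```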
nn.Parameter() sets this property to True implicitly, so this should work, yet it does not!
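That default is easy to confirm on its own:

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(2, 2))
print(p.requires_grad)   # True -- set implicitly by nn.Parameter
print(p.is_leaf)         # True -- parameters are leaves of the graph
```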
What's wrong here? What am I missing?
Any help is greatly appreciated.