Hello all, I created a simple network where a convolutional layer's weight matrix is replaced by the output of a custom function.
I came up with this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


class snet(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 2, 1, 0)
        shape = self.conv1.weight.shape
        self.var1 = nn.Parameter(torch.ones(shape))
        self.var2 = nn.Parameter(torch.ones(shape))
        self.conv2 = nn.Conv2d(6, 6, 5, 1, 0)
        self.fc = nn.Linear(6 * 11 * 11, num_classes)

    def some_method(self):
        """Suppose this is a custom method, tasked with producing values
        for each entry in the weight matrix of a convolutional layer.
        For simplicity, addition is used here."""
        return self.var1 + self.var2

    def forward(self, input):
        self.conv1.weight = nn.Parameter(self.some_method())
        output = self.conv1(input)
        # Using the functional API with the weight matrix directly is no different:
        # output = F.conv2d(input, self.some_method())
        output = self.conv2(output)
        output = output.view(input.size(0), -1)
        output = self.fc(output)
        return output


n = snet(num_classes=3)
fake_dataset = torchvision.datasets.FakeData(100, image_size=(3, 16, 16),
                                             num_classes=3,
                                             transform=transforms.ToTensor())
fake_dataloader = torch.utils.data.DataLoader(fake_dataset, batch_size=20)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(n.parameters(), lr=0.01)

for imgs, labels in fake_dataloader:
    p = n(imgs)
    loss = criterion(p, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
```
Apparently this is wrong, as nothing happens: the parameters are added to the module and show up in the parameters list, yet their gradient is always zero!
I also noticed that the grad_fn property for both variables/parameters is None, whereas it should have been the addition, right?
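To illustrate, here is a minimal sketch of the same observation outside the module, with two standalone parameters standing in for self.var1 and self.var2:

```python
import torch
import torch.nn as nn

# Stand-ins for self.var1 and self.var2
var1 = nn.Parameter(torch.ones(3))
var2 = nn.Parameter(torch.ones(3))

result = var1 + var2
print(result.grad_fn)   # the addition is recorded here (AddBackward0)

# Wrapping the result in a fresh nn.Parameter, as forward() does:
weight = nn.Parameter(result)
print(weight.grad_fn)   # None -- the new parameter carries no history
```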
Based on the autograd tutorial, when one variable in an operation has requires_grad = True, the output will also have requires_grad = True, and thus the gradient should flow back to those variables with requires_grad set to True!
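That propagation is easy to see in isolation, with a plain tensor in place of a network:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.ones(3)            # requires_grad defaults to False
z = (x + y).sum()
print(z.requires_grad)       # True -- propagated from x
z.backward()
print(x.grad)                # gradient flowed back to x
print(y.grad)                # None -- y did not require grad
```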
nn.Parameter() sets this property to True implicitly, so this should work, yet it does not!
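That default is easy to confirm on its own:

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(2, 2))
print(p.requires_grad)   # True -- set implicitly by nn.Parameter
print(p.is_leaf)         # True -- parameters are leaves of the graph
```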
What's wrong here? What am I missing?
Any help is greatly appreciated.