Hi all,
I am trying to implement a convolutional neural network where one of the layers’ weights are constrained by a couple of parameters, which are optimized during backpropagation. Another way to put it: the weights themselves are not directly updated, but the parameters that determine the weight matrices should be. Here is a simplified version of the code:
import torch.nn.functional as F
import torch
from torch import nn
class GaborConv2d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
):
super().__init__()
self.is_calculated = False
self.conv_layer = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
)
self.kernel_size = self.conv_layer.kernel_size
self.test = nn.Parameter((torch.Tensor([2.0])), requires_grad=True)
self.register_parameter("test", self.test)
def forward(self, input_tensor):
if self.training:
self.calculate_weights()
self.is_calculated = False
if not self.training:
if not self.is_calculated:
self.calculate_weights()
self.is_calculated = True
return self.conv_layer(input_tensor)
def calculate_weights(self):
self.conv_layer.weight.data += self.test
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GaborConv2d(1, 1, 3)
self.fc = nn.Linear(9, 3)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = torch.flatten(x, 1)
x = self.fc(x)
x = F.relu(x)
return x
if __name__ == '__main__':
model = Net()
model.optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # only the weights and biases of the fully connected layer are learned
model.loss = nn.CrossEntropyLoss()
for i in range(100):
random_data = torch.rand((1, 1, 5, 5))
model.optimizer.zero_grad()
output = model(random_data)
l = model.loss(output, torch.tensor([0]))
l.backward(retain_graph=True)
model.optimizer.step()
The “test” variable is the parameter that controls the weights. In the optimizer, I do see test listed in the params group. However, its grad is None, and it is not updated during backpropagation. Any input would be appreciated! Thanks!