Hi, I have the following problem for which I cannot find a solution:
There are two models: a pre-trained CNN, which can be assumed to be fixed, and a fully connected model, still to be trained, that transforms the CNN's filter weights based on a parameter.
For example, one possible learned transformation is a rotation of the filters by a given angle. This way, the transformation model could adapt the CNN so that it can classify rotated images.
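As a hedged illustration of why transforming the filters can help: for rotations by multiples of 90°, rotating both the input and a convolution filter rotates the output exactly. The sketch below checks this with `torch.rot90` on random tensors (the helper `rot` is introduced here for illustration). Arbitrary angles, as a learned transformation might produce, would need interpolation, so there the relation is no longer exact.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random single-channel 5x5 image and 3x3 filter, shapes (N, C, H, W)
x = torch.randn(1, 1, 5, 5)
w = torch.randn(1, 1, 3, 3)

def rot(t):
    # Rotate the spatial dimensions (H, W) by 90 degrees
    return torch.rot90(t, k=1, dims=(2, 3))

# Convolving the rotated image with the rotated filter ...
out_rotated_both = F.conv2d(rot(x), rot(w))
# ... matches rotating the output of the original convolution
out_then_rotated = rot(F.conv2d(x, w))

print(torch.allclose(out_rotated_both, out_then_rotated, atol=1e-5))  # True
```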
My problem is that the transformation model does not receive any gradients. I suspect the computation graph is broken when I assign the new filter weights to the CNN. Since the problem is probably easier to understand with code, here is a small illustrative example:
```python
import torch
import torch.nn as nn
import copy

# define models
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

class TransformationNet(nn.Module):
    def __init__(self):
        super(TransformationNet, self).__init__()
        self.linear = nn.Linear(10, 9, bias=False)

    def forward(self, x, param):
        x = torch.cat([x, param], dim=1)
        return self.linear(x)

model_cnn = CNN()
# Since we are changing the CNN filter weights while training, the copy is used to keep the original pre-trained model.
model_cnn_copy = CNN()
model_cnn_copy.load_state_dict(copy.deepcopy(model_cnn.state_dict()))
model_transformation = TransformationNet()
model_transformation.train()

# only optimize the model for transforming the filters
optimizer_model_transformation = torch.optim.SGD(model_transformation.parameters(), lr=0.1, momentum=0.9)

# training loop
for i in range(2):
    B = 1  # batch size
    image = torch.ones((B, 1, 3, 3))  # some input
    # Update the CNN using the model for transforming the filters.
    # The second argument torch.tensor([[1.0]]) is also just chosen arbitrarily for this example.
    model_cnn.conv.weight.data = model_transformation(
        model_cnn_copy.conv.weight.data.view(B, -1), torch.tensor([[1.0]])
    ).view((B, 1, 3, 3))
    out = model_cnn(image)
    # just some arbitrary loss
    loss = out.sum()
    optimizer_model_transformation.zero_grad()
    loss.backward()
    print(model_transformation.linear.weight.grad)
    optimizer_model_transformation.step()
    print(loss.item())
```
The gradient of the linear layer is `None`, and the loss stays constant. I suspect that this line causes the problem:

```python
model_cnn.conv.weight.data = model_transformation(model_cnn_copy.conv.weight.data.view(B, -1), torch.tensor([[1.0]])).view((B,1,3,3))
```
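The suspicion can be checked in isolation, together with one possible workaround. The sketch below rebuilds the two models from the example (the helper `generate_weight` and the name `cond` are introduced here for illustration). Autograd does not record writes through `.data`, so the first variant leaves the transformation model without gradients. The second variant assumes the generated filter only needs to be used inside the forward pass: instead of writing into the parameter, it passes the generated tensor directly to `torch.nn.functional.conv2d`, which keeps it in the graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

class TransformationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 9, bias=False)

    def forward(self, x, param):
        return self.linear(torch.cat([x, param], dim=1))

model_cnn = CNN()                # stands in for the pre-trained, fixed CNN
model_cnn.requires_grad_(False)  # never overwritten and never trained
model_transformation = TransformationNet()
optimizer = torch.optim.SGD(model_transformation.parameters(), lr=0.1, momentum=0.9)

image = torch.ones((1, 1, 3, 3))
cond = torch.tensor([[1.0]])  # hypothetical conditioning parameter, e.g. an angle

def generate_weight():
    # Flatten the single 1x1x3x3 filter to shape (1, 9), transform it, reshape back
    w = model_cnn.conv.weight
    return model_transformation(w.view(1, -1), cond).view_as(w)

# Variant 1: writing through .data (as in the question) cuts the graph
scratch_cnn = CNN()
scratch_cnn.conv.weight.data = generate_weight()
scratch_cnn(image).sum().backward()
print(model_transformation.linear.weight.grad)  # None

# Variant 2: pass the generated weight to the functional convolution instead
losses = []
for i in range(2):
    out = F.conv2d(image, generate_weight(), bias=model_cnn.conv.bias)
    loss = out.sum()
    optimizer.zero_grad()
    loss.backward()  # now reaches model_transformation.linear.weight
    optimizer.step()
    losses.append(loss.item())

print(losses)  # the second loss is lower than the first
```

Because the pre-trained CNN is never overwritten in the second variant, the separate `model_cnn_copy` is no longer needed. For a deeper CNN, where calling every layer functionally by hand is impractical, `torch.func.functional_call` (PyTorch 2.0 and later) can run the unmodified module with a dictionary of replacement tensors keyed by parameter name, e.g. `{"conv.weight": new_weight}`.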
Thank you very much for any hints or ideas!