# Learning to transform CNN filter weights

Hi, I have the following problem, for which I cannot find a solution:

There are two models: a pre-trained CNN, which can be assumed to be fixed, and a fully connected model to be trained, which transforms the filter weights of the CNN based on a parameter.

For example, a possible learned transformation could be a rotation of the filters by a certain angle. This way, the transformation model could transform the CNN such that it might be capable of classifying rotated images.
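To make the rotation example concrete, here is a minimal sketch, assuming a fixed 90° rotation implemented with `torch.rot90` rather than a learned transformation. The random `image` and `kernel` tensors are placeholders chosen for illustration. It checks that convolving a rotated image with an equally rotated filter gives the rotated version of the original output:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image = torch.randn(1, 1, 5, 5)   # placeholder input (N, C, H, W)
kernel = torch.randn(1, 1, 3, 3)  # placeholder 3x3 filter

# Output for the original image and filter.
out = F.conv2d(image, kernel)

# Rotate both the image and the filter by 90 degrees in the spatial dims.
image_rot = torch.rot90(image, 1, dims=(2, 3))
kernel_rot = torch.rot90(kernel, 1, dims=(2, 3))
out_rot = F.conv2d(image_rot, kernel_rot)

# The response to the rotated image equals the rotated original response.
matches = torch.allclose(out_rot, torch.rot90(out, 1, dims=(2, 3)), atol=1e-6)
print(matches)
```

A learned transformation would have to approximate this kind of mapping for arbitrary angles, which is what the parameter input is meant to control.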

My problem is that I cannot get gradients for the transformation model. I suspect that the graph is broken when the transformed weights are assigned to the CNN parameter. Since the problem is probably easier to understand with code, here is a small illustrative example:

```python
import torch
import torch.nn as nn
import copy

# define models
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

class TransformationNet(nn.Module):
    def __init__(self):
        super(TransformationNet, self).__init__()
        self.linear = nn.Linear(10, 9, bias=False)

    def forward(self, x, param):
        x = torch.cat([x, param], dim=1)
        return self.linear(x)

model_cnn = CNN()
# Since we are changing the CNN filter weights while training, the copy is used to keep the original pre-trained model.
model_cnn_copy = CNN()

model_transformation = TransformationNet()
model_transformation.train()

# only optimize the model for transforming the filters
optimizer_model_transformation = torch.optim.SGD(model_transformation.parameters(), lr=0.1, momentum=0.9)

# training loop
for i in range(2):
    B = 1  # batch size
    image = torch.ones((B, 1, 3, 3))  # some input

    # update the CNN using the model for transforming the filters. The second argument torch.tensor([[1.0]]) is also just chosen arbitrarily for this example.
    model_cnn.conv.weight.data = model_transformation(model_cnn_copy.conv.weight.data.view(B, -1), torch.tensor([[1.0]])).view((B, 1, 3, 3))

    out = model_cnn(image)

    # just some arbitrary loss
    loss = out.sum()

    loss.backward()
```
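Assigning to `.data` bypasses autograd, so the weights the CNN ends up using no longer have any connection to the transformation model, and `loss.backward()` has nothing to propagate into it. One possible way around this, sketched here under the assumption that the CNN is a single convolution as in the example, is to skip the assignment and pass the transformed weights directly to `torch.nn.functional.conv2d`. The names `conv`, `transform`, and `param` are illustrative only and not taken from the code above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Frozen stand-in for the pre-trained CNN (a single 3x3 convolution).
conv = nn.Conv2d(1, 1, 3)
for p in conv.parameters():
    p.requires_grad_(False)

# Transformation model: 9 filter weights + 1 parameter -> 9 new filter weights.
transform = nn.Linear(10, 9, bias=False)
optimizer = torch.optim.SGD(transform.parameters(), lr=0.1, momentum=0.9)

image = torch.ones(1, 1, 3, 3)  # some input
param = torch.tensor([[1.0]])   # arbitrary transformation parameter

for step in range(2):
    optimizer.zero_grad()

    # Build the transformed filter as an ordinary tensor in the graph.
    flat = torch.cat([conv.weight.view(1, -1), param], dim=1)
    new_weight = transform(flat).view_as(conv.weight)

    # Functional convolution: no assignment to conv.weight, so the graph
    # from the loss back to the transformation model stays intact.
    out = F.conv2d(image, new_weight, conv.bias)

    loss = out.sum()  # just some arbitrary loss
    loss.backward()
    optimizer.step()
```

After `loss.backward()`, `transform.weight.grad` is populated while the frozen `conv.weight` receives no gradient. For a deeper network, where rewriting every layer functionally is impractical, `torch.func.functional_call` (available in PyTorch 2.0 and later) may serve the same purpose by running an existing module with substituted parameter tensors.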