Learning to transform CNN filter weights

Hi, I have the following problem, for which I cannot find a solution:

There are two models: a pre-trained CNN, which can be assumed to be fixed, and a fully connected model to be trained, which transforms the filter weights of the CNN based on a parameter.

For example, the learned transformation could rotate the filters by a given angle. The transformation model could then adapt the CNN so that it is able to classify rotated images.
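To make the rotation example concrete, here is a minimal sketch of a hand-written (not learned) filter rotation. It assumes conv weights of shape (out_channels, in_channels, k, k) and uses `torch.nn.functional.affine_grid` and `grid_sample` with bilinear interpolation, so the rotated filter stays differentiable with respect to the original weights. The function name `rotate_filters` is hypothetical; a trained transformation model would only approximate something like this.

```python
import math
import torch
import torch.nn.functional as F

def rotate_filters(weight: torch.Tensor, angle: float) -> torch.Tensor:
    """Hypothetical helper: rotate every k x k filter of a conv weight
    (out_ch, in_ch, k, k) by `angle` radians via bilinear resampling.
    Values sampled from outside the filter are treated as zero."""
    out_ch, in_ch, kh, kw = weight.shape
    c, s = math.cos(angle), math.sin(angle)
    # One 2x3 affine matrix per output channel (each is treated as a batch item).
    theta = torch.tensor([[c, -s, 0.0], [s, c, 0.0]],
                         dtype=weight.dtype, device=weight.device)
    theta = theta.unsqueeze(0).repeat(out_ch, 1, 1)
    grid = F.affine_grid(theta, size=[out_ch, in_ch, kh, kw], align_corners=False)
    return F.grid_sample(weight, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)

w = torch.arange(9.0).view(1, 1, 3, 3).requires_grad_()
rotated = rotate_filters(w, math.pi / 2)
rotated.sum().backward()  # gradients flow back to the original filter
```

Angles that are not multiples of 90 degrees blur the 3x3 filter through interpolation, which is one reason a learned transformation might be preferable to a fixed one.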

My problem is that I do not get any gradients for the transformation model. I guess that assigning the new values to the filter weights breaks the computational graph. Since the problem is probably easier to understand with code, here is a small illustrative example:

```python
import torch
import torch.nn as nn
import copy

# define models
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

class TransformationNet(nn.Module):
    def __init__(self):
        super().__init__()
        # input: 9 filter weights + 1 transformation parameter; output: 9 new filter weights
        self.linear = nn.Linear(10, 9, bias=False)

    def forward(self, x, param):
        x = torch.cat([x, param], dim=1)
        return self.linear(x)

model_cnn = CNN()  # stands in for the pre-trained CNN
# Since we change the CNN filter weights while training, the copy keeps the original pre-trained weights.
model_cnn_copy = copy.deepcopy(model_cnn)

model_transformation = TransformationNet()

# only optimize the model for transforming the filters
optimizer_model_transformation = torch.optim.SGD(model_transformation.parameters(), lr=0.1, momentum=0.9)

# training loop
for i in range(2):
    B = 1  # batch size
    image = torch.ones((B, 1, 3, 3))  # some input
    # Update the CNN filter (out_channels x in_channels x 3 x 3 = 1x1x3x3, independent of B) with the transformation model.
    # The second argument torch.tensor([[1.0]]) is also chosen arbitrarily for this example.
    model_cnn.conv.weight.data = model_transformation(model_cnn_copy.conv.weight.data.view(1, -1), torch.tensor([[1.0]])).view(1, 1, 3, 3)
    out = model_cnn(image)
    # just some arbitrary loss
    loss = out.sum()

    optimizer_model_transformation.zero_grad()
    loss.backward()
    optimizer_model_transformation.step()
    print(i, loss.item(), model_transformation.linear.weight.grad)
```

The gradient of the linear layer's weight stays None, and the loss does not change between iterations. I assume the line

```python
model_cnn.conv.weight.data = model_transformation(model_cnn_copy.conv.weight.data.view(1, -1), torch.tensor([[1.0]])).view(1, 1, 3, 3)
```

causes the problem.
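One way to check this guess is the small sketch below. It assumes a single 1x1x3x3 filter as in the example and uses `torch.nn.functional.conv2d` as one possible way to apply generated weights. Writing through `.data` copies values into the existing parameter without recording the operation in autograd, so the parameter keeps no link to the transformation model. Passing the generated tensor directly to the functional convolution keeps that link. The names `pretrained`, `transform`, and `flat_filter` are illustrative only.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformationNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 9 filter weights + 1 transformation parameter -> 9 new filter weights
        self.linear = nn.Linear(10, 9, bias=False)

    def forward(self, x, param):
        return self.linear(torch.cat([x, param], dim=1))

torch.manual_seed(0)
pretrained = nn.Conv2d(1, 1, 3)  # stands in for the fixed pre-trained CNN
transform = TransformationNet()
image = torch.ones(1, 1, 3, 3)
param = torch.tensor([[1.0]])
flat_filter = pretrained.weight.detach().view(1, -1)  # shape (1, 9)

# (a) Write the transformed filter through .data, as in the question.
target = copy.deepcopy(pretrained)
target.weight.data = transform(flat_filter, param).view(1, 1, 3, 3)
target(image).sum().backward()
grad_via_data = transform.linear.weight.grad  # None: no graph link to transform

# (b) Pass the transformed filter straight into the functional convolution.
new_weight = transform(flat_filter, param).view_as(pretrained.weight)
loss = F.conv2d(image, new_weight, pretrained.bias.detach()).sum()
loss.backward()
grad_via_functional = transform.linear.weight.grad  # now a (9, 10) tensor

print(grad_via_data)                    # None
print(grad_via_functional is not None)  # True
```

In the training loop this would mean computing the output with `F.conv2d(image, new_weight, ...)` instead of calling `model_cnn(image)` after the assignment; a deeper CNN would need the same treatment for every convolution whose weights are generated.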

Thank you very much for any hints or ideas!