How can I replace a model's parameters and still get the gradients correctly?

I have a ResNet model and I want to replace the weight of one of its Conv layers with a tensor built from my own parameters A and B.
Here is the code:

```python
import torch

# `resnet` is my own module: the model takes a dict and returns a dict containing "loss"
model = resnet.resnet18()

# my_param = torch.nn.Parameter(torch.zeros(size=(64, 64, 3, 3)))

A = torch.nn.Parameter(torch.zeros(size=(64 * 64, 32 * 32)))  # A is my desired leaf Parameter
B = torch.nn.Parameter(torch.zeros(size=(32, 32, 3, 3)))  # B is my desired leaf Parameter

B.requires_grad_(True)
B.retain_grad()

# note: this rebinds the name B to a non-leaf view; the leaf Parameter is no longer reachable as B
B = B.view(-1, 9)

base = A @ B  # matrix multiplication to create 'base'

base.requires_grad_(True)
base.retain_grad()

base = base.view(64, 64, 3, 3)

A.requires_grad_(True)
A.retain_grad()
B.requires_grad_(True)
B.retain_grad()

model.layer1[0].conv1.weight = torch.nn.Parameter(base)  # replace the params with 'base'

# my random img
img = {}
img["img"] = torch.randn(size=(1, 3, 28, 28))
img["label"] = torch.ones(1).long()

out = model(img)
print(out)

out["loss"].backward()  # my model already computes the loss

print(A.grad, B.grad)  # prints: None None
```

But `A.grad` and `B.grad` are both `None`.

How can I get their gradients correctly?
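The same behavior can be reproduced without the custom ResNet. This is a minimal sketch, assuming a single `nn.Conv2d(64, 64, 3)` stands in for `model.layer1[0].conv1` and a summed output stands in for the loss. It shows that `nn.Parameter(base)` is a fresh leaf, so the gradient lands in `conv.weight.grad` and never reaches `A` or `B`:

```python
import torch
import torch.nn as nn

# Stand-in for model.layer1[0].conv1 (assumption: same 64 -> 64, 3x3 weight shape).
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)

# Same shapes as in the question, with a small random init.
A = nn.Parameter(torch.randn(64 * 64, 32 * 32) * 0.01)
B = nn.Parameter(torch.randn(32, 32, 3, 3) * 0.01)

base = (A @ B.view(-1, 9)).view(64, 64, 3, 3)
print(base.grad_fn is not None)  # True: base still depends on A and B

conv.weight = nn.Parameter(base)  # wraps base in a brand-new leaf tensor
print(conv.weight.is_leaf, conv.weight.grad_fn)  # True None

x = torch.randn(1, 64, 8, 8)
conv(x).sum().backward()

print(A.grad, B.grad)  # None None
print(conv.weight.grad is not None)  # True: the gradient stops at the new leaf
```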

I have tried several approaches:

  1. `model.load_state_dict()`: the gradients are still `None`.

  2. `model.layer1[0].conv1.weight = base`: PyTorch raises `TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)`.

  3. `model.layer1[0].conv1.weight.data = nn.Parameter(base)`: the gradients are still `None`.

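That's expected, since you are creating a new `nn.Parameter` without a gradient history: wrapping `base` in `nn.Parameter` creates a new leaf tensor, so the gradient is accumulated in `conv1.weight.grad` and never flows back to `A` or `B`. Assigning to `.data` or calling `load_state_dict()` only copies values and doesn't record the operation in the autograd graph either.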
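One workaround that keeps the history is sketched here as an alternative technique, not the approach recommended in this reply: unregister the parameter with `del` and assign the non-leaf `base` as a plain tensor attribute. `nn.Module.__setattr__` only raises the `TypeError` from attempt 2 while `weight` is still registered as a parameter. This assumes the same stand-in `nn.Conv2d` as a substitute for `model.layer1[0].conv1`:

```python
import torch
import torch.nn as nn

# Stand-in for model.layer1[0].conv1 (assumption: same 64 -> 64, 3x3 weight shape).
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)

A = nn.Parameter(torch.randn(64 * 64, 32 * 32) * 0.01)
B = nn.Parameter(torch.randn(32, 32, 3, 3) * 0.01)

# Unregister the original Parameter; afterwards `weight` can hold a plain tensor.
del conv.weight

x = torch.randn(1, 64, 8, 8)

# base has to be rebuilt before every forward pass: backward() frees the graph
# from A and B, so a stale base cannot be differentiated a second time.
conv.weight = (A @ B.view(-1, 9)).view(64, 64, 3, 3)
loss = conv(x).sum()
loss.backward()

print(A.grad.shape, B.grad.shape)  # torch.Size([4096, 1024]) torch.Size([32, 32, 3, 3])
```

With this approach `weight` no longer appears in `parameters()` or `state_dict()`, so `A` and `B` have to be passed to the optimizer explicitly.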
You could check this post to see if `torch.nn.utils.parametrize` would work for your use case.
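A minimal sketch of the parametrization route, assuming a stand-in `nn.Conv2d` with the same 64 -> 64, 3x3 weight shape; the class name `FactoredConvWeight` is hypothetical. The parametrization module owns `A` and `B` and rebuilds the weight from them on every access, so autograd sees the whole chain:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class FactoredConvWeight(nn.Module):
    """Hypothetical parametrization: weight = (A @ B.view(-1, 9)).view(64, 64, 3, 3)."""

    def __init__(self):
        super().__init__()
        # Small random init: with all-zero A and B, both gradients would be zero,
        # since each factor's gradient is multiplied by the other factor.
        self.A = nn.Parameter(torch.randn(64 * 64, 32 * 32) * 0.01)
        self.B = nn.Parameter(torch.randn(32, 32, 3, 3) * 0.01)

    def forward(self, original):
        # `original` is the conv's original weight; it is ignored here, so the
        # weight is computed from A and B alone.
        return (self.A @ self.B.view(-1, 9)).view(64, 64, 3, 3)


# Stand-in for model.layer1[0].conv1 (assumption: same 64 -> 64, 3x3 shape).
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
parametrize.register_parametrization(conv, "weight", FactoredConvWeight())

# The original weight is not used by the parametrization, so freeze it.
conv.parametrizations.weight.original.requires_grad_(False)

x = torch.randn(1, 64, 8, 8)
conv(x).sum().backward()

factors = conv.parametrizations.weight[0]
print(factors.A.grad.shape, factors.B.grad.shape)  # torch.Size([4096, 1024]) torch.Size([32, 32, 3, 3])
```

Applied to the real model, the call would be `parametrize.register_parametrization(model.layer1[0].conv1, "weight", FactoredConvWeight())`. Because the parametrization is registered as a submodule, `A` and `B` show up in `model.parameters()` and are picked up by an optimizer built from it.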