How can we assign an external tensor to the nn.Parameters of a network and keep the gradient flowing out of the network, back to the external source of the parameters? The problem is that when I assign the tensor to a layer's weight, I can't find a way to make autograd see that the assigned weight is a slice of a larger parameter tensor.
Create the parameters (creating `theta` directly on the device, rather than calling `.to(device)` afterwards, keeps it a leaf tensor so `theta.grad` will actually be populated):

```python
model = nn.Sequential(nn.Linear(1, 1)).to(device)
theta = torch.tensor([[1.0, 2.0]], device=device, requires_grad=True)
```

Select a slice of `theta` to be assigned as the weight:

```python
w = torch.reshape(theta[0][0:1], model._modules['0']._parameters['weight'].shape)
```

This returns a tensor with the right `grad_fn`:

```
tensor([[1.]], device='cuda:0', grad_fn=<AsStridedBackward>)
```
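To sanity-check that such a slice really keeps the autograd graph intact, here is a minimal CPU-only sketch (shapes chosen to match the example above):

```python
import torch

# Minimal check that a reshaped slice of a larger parameter vector
# keeps its grad_fn, so backward() can reach the external tensor.
theta = torch.tensor([[1.0, 2.0]], requires_grad=True)
w = torch.reshape(theta[0][0:1], (1, 1))

assert w.grad_fn is not None  # w is part of the graph, not a leaf
w.sum().backward()
print(theta.grad)             # tensor([[1., 0.]]) -- only the sliced entry gets a gradient
```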
Assigning via nn.Parameter does not work, presumably because the tensor is detached from the graph when the Parameter is constructed (a Parameter must be a leaf tensor).
Since you won't be learning `weight` directly anymore, it doesn't need to be a Parameter. So you can do the following:
```python
mod = model._modules['0']
del mod.weight
mod.weight = w
```
You need to be careful, though, to delete `.weight` before the module is passed to the optimizer, and to repopulate it with a tensor of the right size before calling forward.
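Putting that suggestion together, here is a minimal CPU sketch (shapes and names follow the example above) of deleting the Parameters and assigning plain tensor views in their place:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 1))
theta = torch.tensor([[1.0, 2.0]], requires_grad=True)

mod = model._modules['0']
del mod.weight                                      # remove the registered nn.Parameter
mod.weight = torch.reshape(theta[0][0:1], (1, 1))   # plain tensor view, keeps grad_fn
del mod.bias
mod.bias = theta[0][1:2]

out = model(torch.ones(1, 1))   # 1 * 1 + 2 = 3
out.sum().backward()
print(theta.grad)               # tensor([[1., 1.]]) -- gradient reaches the external vector
```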
This solution seems to work so far. The idea is to first remove all Parameters in the Module and replace them with plain tensors stored under object attributes of the same name. Here is what I used:
```python
def flip_parameters_to_tensors(module):
    attr = []
    while bool(module._parameters):
        attr.append(module._parameters.popitem())
    setattr(module, 'registered_parameters_name', [])

    for name, param in attr:
        setattr(module, name, torch.zeros(param.shape, requires_grad=True))
        module.registered_parameters_name.append(name)

    module_name = [k for k, v in module._modules.items()]
    for name in module_name:
        flip_parameters_to_tensors(module._modules[name])
```
Then, we can use the saved list of previously registered parameter names to assign the tensors:
```python
def set_all_parameters(module, theta):
    count = 0
    for name in module.registered_parameters_name:
        a = count
        b = a + getattr(module, name).numel()
        t = torch.reshape(theta[0, a:b], getattr(module, name).shape)
        setattr(module, name, t)
        count = b

    # Pass an offset view to children so each submodule reads its own
    # slice of theta (passing the full theta would make every child
    # start reading again from index 0).
    module_name = [k for k, v in module._modules.items()]
    for name in module_name:
        count += set_all_parameters(module._modules[name], theta[:, count:])
    return count
```
This way, the flattened vector is assigned to all tensors of the NN for evaluation, and backward() propagates back to the parameter vector outside the NN.
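For reference, here is a self-contained end-to-end sketch of the whole approach on a small two-layer net (the model and sizes are my own illustration); note that the recursion passes `theta[:, count:]` to children so each submodule reads its own slice of the flattened vector:

```python
import torch
import torch.nn as nn

def flip_parameters_to_tensors(module):
    # Remove every registered Parameter and replace it with a plain
    # tensor attribute of the same name (placeholder zeros for now).
    attr = []
    while bool(module._parameters):
        attr.append(module._parameters.popitem())
    setattr(module, 'registered_parameters_name', [])
    for name, param in attr:
        setattr(module, name, torch.zeros(param.shape, requires_grad=True))
        module.registered_parameters_name.append(name)
    for child in module._modules.values():
        flip_parameters_to_tensors(child)

def set_all_parameters(module, theta):
    # Carve views out of the flattened (1, N) vector theta and assign them.
    count = 0
    for name in module.registered_parameters_name:
        a = count
        b = a + getattr(module, name).numel()
        t = torch.reshape(theta[0, a:b], getattr(module, name).shape)
        setattr(module, name, t)
        count = b
    # Offset theta so each child consumes its own slice.
    for child in module._modules.values():
        count += set_all_parameters(child, theta[:, count:])
    return count

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))
n = sum(p.numel() for p in model.parameters())   # 9 for this model
theta = torch.randn(1, n, requires_grad=True)

flip_parameters_to_tensors(model)
used = set_all_parameters(model, theta)

out = model(torch.ones(1, 2))
out.sum().backward()
print(used, theta.grad.shape)                    # 9 torch.Size([1, 9])
```

(For newer PyTorch versions, `torch.func.functional_call` provides a supported way to evaluate a module with externally supplied parameter tensors.)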