Say I have two independently trained models (with identical architecture) with parameters params1 and params2. I’d like to find out whether there exist real values w1 and w2 such that the model with parameters (w1 x params1 + w2 x params2) / 2 performs well on some validation set.
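For reference, here is one way that idea could be sketched without mutating either model, using `torch.func.functional_call` (available in recent PyTorch; the models and loss here are just placeholders, not the poster's actual setup):

```python
import torch
from torch import nn
from torch.func import functional_call

# Two "independently trained" models with identical architecture (random here).
model1 = nn.Linear(2, 2)
model2 = nn.Linear(2, 2)

# Learnable mixing coefficients w1 and w2.
w1 = torch.tensor(0.5, requires_grad=True)
w2 = torch.tensor(0.5, requires_grad=True)

params2 = dict(model2.named_parameters())

# Combined parameters (w1 * params1 + w2 * params2) / 2, keeping autograd history.
combined = {
    name: (w1 * p1 + w2 * params2[name]) / 2
    for name, p1 in model1.named_parameters()
}

# Evaluate the combined model; gradients flow back to w1 and w2.
x = torch.randn(4, 2)
loss = functional_call(model1, combined, (x,)).pow(2).sum()
loss.backward()
```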
To test this, I’ve written the following piece of code.
Can you share your vector_to_parameters function please?
What most likely happens is that you put that aggr_params back into an nn.Parameter, which cannot have history, and so it breaks the gradient graph.
To fix this, the simplest solution is to delete the nn.Parameter (this is not a parameter anymore as you don’t want to learn it) and replace it with a regular Tensor containing the value you want.
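A minimal sketch of that replacement on a single Linear module (the names here are just for illustration):

```python
import torch
from torch import nn

lin = nn.Linear(2, 2)

w = torch.randn(2, 2, requires_grad=True)

# Delete the nn.Parameter, then assign a plain Tensor that keeps its history.
delattr(lin, "weight")
lin.weight = w * 2  # regular Tensor, not an nn.Parameter

out = lin(torch.randn(3, 2)).sum()
out.backward()
print(w.grad is not None)  # True: the gradient flows back to w
```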
Wow, I didn’t know we had that. That is a very, very dangerous function. We need to fix that.
The issue is slightly different here then: that function uses .data (which should never be used!) and so breaks the computational graph. I’m afraid you won’t be able to use this function if you want gradients flowing back through this update.
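To illustrate, a small check showing that no gradient reaches the vector when it goes through torch.nn.utils.vector_to_parameters (model and input are arbitrary):

```python
import torch
from torch import nn
from torch.nn.utils import vector_to_parameters

model = nn.Linear(2, 2)  # 2x2 weight + 2 bias = 6 values

vec = torch.rand(6, requires_grad=True)
vector_to_parameters(vec, model.parameters())  # copies via .data under the hood

out = model(torch.randn(3, 2)).sum()
out.backward()
print(vec.grad)  # None: the .data copy broke the graph back to vec
```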
It’s not going to be very clean. But the following should work:
```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(2, 2))

# Save the location of the parameters now.
# Since we delete them later, we won't be able to call this function anymore.
model_params = list(model.named_parameters())

agreg = torch.rand(6, requires_grad=True)

def get_last_module(model, indices):
    mod = model
    for idx in indices:
        mod = getattr(mod, idx)
    return mod

def replace_weights(agreg, model, model_params):
    pointer = 0
    for name, p in model_params:
        indices = name.split(".")
        mod = get_last_module(model, indices[:-1])
        p_name = indices[-1]
        if isinstance(p, nn.Parameter):
            # We can override Tensors just fine, only nn.Parameters have custom logic
            delattr(mod, p_name)
        num_param = p.numel()
        setattr(mod, p_name, agreg[pointer:pointer + num_param].view_as(p))
        pointer += num_param

print(model)
print(model[0].weight)

replace_weights(agreg, model, model_params)
print(model[0].weight)

agreg = agreg * 10
replace_weights(agreg, model, model_params)
print(model[0].weight)
```
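For completeness, a quick self-contained sanity check (the same delattr/setattr pattern, inlined for a single Linear layer) that the gradient does reach the flat vector with this approach:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(2, 2))
params = list(model.named_parameters())

agreg = torch.rand(6, requires_grad=True)

# Replace each nn.Parameter with a view of agreg that keeps its history.
pointer = 0
for name, p in params:
    mod = model[0]
    p_name = name.split(".")[-1]
    delattr(mod, p_name)
    num_param = p.numel()
    setattr(mod, p_name, agreg[pointer:pointer + num_param].view_as(p))
    pointer += num_param

out = model(torch.randn(4, 2)).sum()
out.backward()
print(agreg.grad is not None)  # True: the gradient flows back to agreg
```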