Combining two models

Say I have two independently trained models (with identical architecture) with parameters params1 and params2. I’d like to find out whether there exist real values w1 and w2 such that the model with parameters (w1 * params1 + w2 * params2) / 2 performs well on some validation set.

To test this, I’ve written the following piece of code.

w = torch.ones(2, 1, requires_grad=True, device=args.device)
optimizer = optim.SGD([w], lr=1e-1, momentum=0.0)

for rnd in tqdm(range(1, args.epochs+1)):
    model.train()
    val_loss, val_acc = 0.0, 0.0 
    for _, (inputs, labels) in enumerate(val_loader):
        # pass inputs to device, clear gradients
        inputs, labels = inputs.to(device=args.device, non_blocking=True),\
                        labels.to(device=args.device, non_blocking=True)
        optimizer.zero_grad()
        
        # compute aggregated params
        aggr_params = (w[0]*model_1_params + w[1]*model_2_params) / 2
        # Load params to the model
        vector_to_parameters(aggr_params, model.parameters())
        # forward-backward pass 
        outputs = model(inputs)
        minibatch_loss = criterion(outputs, labels)
        minibatch_loss.backward()
        
        optimizer.step()
        
        # keep track of round loss/accuracy
        val_loss += minibatch_loss.item()*outputs.shape[0]
        _, pred_labels = torch.max(outputs, 1)
        val_acc += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
        
    # inference after epoch
    print(w)
    val_loss, val_acc = val_loss/len(val_dataset), val_acc/len(val_dataset)       
    print(f'| Valid Loss:{val_loss:.3f}|', end='--')
    print(f'| Valid Acc: {val_acc:.3f}|', end='\r')

However, this does not update the weights w at all. I suspect the line

vector_to_parameters(aggr_params, model.parameters())

might be breaking autograd. However, I’m not sure.

Hi,

Can you share your vector_to_parameters function please?

What most likely happens is that you put aggr_params back into an nn.Parameter, which cannot have history, and so the gradient graph is broken at that point.
To fix this, the simplest solution is to delete the nn.Parameter (it is not a parameter anymore, since you don’t want to learn it) and replace it with a regular Tensor containing the value you want.
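As a rough illustration of that idea, here is a minimal sketch on a toy nn.Linear layer. The layer, the random tensors p1 and p2 (stand-ins for the two trained weight matrices) and the input batch are all assumptions made for the example, not part of the original code:

```python
import torch
from torch import nn

torch.manual_seed(0)

lin = nn.Linear(2, 1)                      # toy layer (assumption)
w = torch.ones(2, requires_grad=True)      # mixing coefficients we want to learn
p1, p2 = torch.randn(1, 2), torch.randn(1, 2)  # stand-ins for trained weights

# Non-leaf tensor: it carries history back to w
aggr = (w[0] * p1 + w[1] * p2) / 2

# nn.Module refuses to store a plain Tensor under a registered parameter name
rejected = False
try:
    lin.weight = aggr
except TypeError:
    rejected = True

# Deleting the nn.Parameter first turns `weight` into an ordinary attribute
del lin.weight
lin.weight = aggr

lin(torch.randn(4, 2)).sum().backward()
print(rejected)   # True
print(w.grad)     # a tensor of shape (2,), no longer None
```

The bias is left as a regular nn.Parameter here; only the attribute that was replaced carries history back to w.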

Hi,

vector_to_parameters is not my own implementation. It is a PyTorch function.

Wow, I didn’t know we had that. That is a very very very very dangerous function :smiley: We need to fix that.

The issue is slightly different here then: it uses .data (which should never be used!) and so breaks the computational graph. I’m afraid you won’t be able to use this function if you want gradients to flow back through this update.
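The effect can be reproduced with a small sketch. The toy nn.Linear model and the random vectors params_1 and params_2 are assumptions standing in for the real models, and the behaviour shown relies on vector_to_parameters copying values into the existing parameters rather than linking them to the input vector:

```python
import torch
from torch import nn
from torch.nn.utils import vector_to_parameters

torch.manual_seed(0)

model = nn.Linear(2, 1)                     # toy model (assumption): 2 weights + 1 bias
params_1 = torch.randn(3)                   # stand-ins for the two flattened models
params_2 = torch.randn(3)
w = torch.ones(2, requires_grad=True)

aggr = (w[0] * params_1 + w[1] * params_2) / 2
# Only the values are copied into model.parameters(); the link to w is lost
vector_to_parameters(aggr, model.parameters())

loss = model(torch.randn(4, 2)).sum()
loss.backward()

print(w.grad)                         # None: no gradient reached w
print(model.weight.grad is not None)  # True: the graph stops at the parameters
```

This matches the symptom in the question: backward() runs without error, but the optimizer has nothing to apply to w.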

I see, that’s ok I guess :slight_smile: In that case, how can I load the aggregated parameters into the model?

It’s not going to be very clean. But the following should work:

import torch
from torch import nn


model = nn.Sequential(nn.Linear(2, 2))
# Save the location of the parameters now
# Since we delete them later, we won't be able to call this function anymore
model_params = list(model.named_parameters())

agreg = torch.rand(6, requires_grad=True)

def get_last_module(model, indices):
    mod = model
    for idx in indices:
        mod = getattr(mod, idx)
    return mod

def replace_weights(agreg, model, model_params):
    pointer = 0
    for name, p in model_params:
        indices = name.split(".")
        mod = get_last_module(model, indices[:-1])
        p_name = indices[-1]
        if isinstance(p, nn.Parameter):
            # We can override Tensors just fine, only nn.Parameters have custom logic
            delattr(mod, p_name)

        num_param = p.numel()
        setattr(mod, p_name, agreg[pointer:pointer + num_param].view_as(p))
        pointer += num_param


print(model)

print(model[0].weight)
replace_weights(agreg, model, model_params)
print(model[0].weight)

agreg = agreg * 10
replace_weights(agreg, model, model_params)
print(model[0].weight)
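A possible alternative on recent PyTorch versions (2.0 and later, which is an assumption about your setup) is torch.func.functional_call, which runs a module with externally supplied tensors and leaves the registered parameters untouched. The sketch below is not the approach from this thread; the helper unflatten and the random vectors params_1 and params_2 are hypothetical names introduced for illustration:

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(2, 2))
# Record name, shape and size of each parameter, in registration order
layout = [(n, p.shape, p.numel()) for n, p in model.named_parameters()]
n_total = sum(numel for _, _, numel in layout)

params_1 = torch.randn(n_total)            # stand-ins for the two flattened models
params_2 = torch.randn(n_total)
w = torch.ones(2, requires_grad=True)


def unflatten(vec):
    """Hypothetical helper: slice a flat vector into a {name: tensor} dict."""
    out, pointer = {}, 0
    for name, shape, numel in layout:
        out[name] = vec[pointer:pointer + numel].view(shape)
        pointer += numel
    return out


aggr = (w[0] * params_1 + w[1] * params_2) / 2
# Forward pass using the aggregated tensors instead of model's own parameters
outputs = functional_call(model, unflatten(aggr), (torch.randn(4, 2),))
outputs.sum().backward()

print(w.grad)  # a tensor of shape (2,), so an optimizer over [w] can step
```

Since the slices are views of aggr, the gradient flows back to w without deleting any attributes from the model.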

Perfect, works like a charm. Thank you so much for your time.