Combining select parameters from two models during training

Hi,

I’m training two MONAI UNet models with the same architecture and combining their parameters during training. Specifically, at regular intervals (every 10 epochs in the script below) I want to average all but the last 4 layers of the two models (here, all layers whose names do not start with “model.2”). Here is part of the script I use:

max_epochs = 1000
val_interval = 2

best_metric = -1

best_metric_1 = -1
best_metric_2 = -1

best_metric_epoch = -1
best_metric_epoch_1 = -1
best_metric_epoch_2 = -1

epoch_loss_values = []
metric_values = []

epoch_loss_values_1 = []
metric_values_1 = []

epoch_loss_values_2 = []
metric_values_2 = []

post_pred_1 = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label_1 = Compose([AsDiscrete(to_onehot=2)])

post_pred_2 = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label_2 = Compose([AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
  epoch_loss_1 = 0
  epoch_loss_2 = 0

  step_0 = 0
  step_1 = 0
  
  # For one epoch
  
  print("-" * 10)
  print(f"epoch {epoch + 1}/{max_epochs}")
  
  model_1.train()
  model_2.train()
  
  # One pass over training set 1 with UNet 1
  for batch_data_1 in train_loader_1:
      step_0 += 1
      inputs_1, labels_1 = (
          batch_data_1["image"].to(device),
          batch_data_1["label"].to(device),
      )
      optimizer_1.zero_grad()
      outputs_1 = model_1(inputs_1)
      
      loss_1 = loss_function_1(outputs_1, labels_1)
      loss_1.backward()
      
      optimizer_1.step()
      epoch_loss_1 += loss_1.item()
      print(
          f"{step_0}/{len(train_ds_1) // train_loader_1.batch_size}, "
          f"train_loss: {loss_1.item():.4f}")
      wandb.log({"train/loss 1": loss_1.item()})
  epoch_loss_1 /= step_0
  epoch_loss_values_1.append(epoch_loss_1)
  print(f"epoch {epoch + 1} average loss 1: {epoch_loss_1:.4f}")


  # One pass over training set 2 with UNet 2
  for batch_data_2 in train_loader_2:
      step_1 += 1
      inputs_2, labels_2 = (
          batch_data_2["image"].to(device),
          batch_data_2["label"].to(device),
      )
      optimizer_2.zero_grad()
      outputs_2 = model_2(inputs_2)
      loss_2 = loss_function_2(outputs_2, labels_2)
      loss_2.backward()
      optimizer_2.step()
      epoch_loss_2 += loss_2.item()
      print(
          f"{step_1}/{len(train_ds_2) // train_loader_2.batch_size}, "
          f"train_loss: {loss_2.item():.4f}")
      wandb.log({"train/loss 2": loss_2.item()})
  epoch_loss_2 /= step_1
  epoch_loss_values_2.append(epoch_loss_2)
  print(f"epoch {epoch + 1} average loss 2: {epoch_loss_2:.4f}")
  
  scheduler_1.step()
  scheduler_2.step()

  # Store weights before aggregation strategy
  
  if epoch % 10 == 0:
    meta_model = []
    weights_2 = []
    weights_1 = []
    # Aggregate weights
    for name, param in model_2.named_parameters():
      if "model.2" not in name:
        weights_2.append(param)

    for name, param in model_1.named_parameters():
      if "model.2" not in name:
        weights_1.append(param)

    for weight_1, weight_2 in zip(weights_1,weights_2):
      meta_model.append((weight_1 + weight_2)/2)   # Change aggregation strategy

    for index_old, param_old in enumerate(model_1.parameters()):
      for index_new, param_new in enumerate(meta_model):
        param_old = param_new

    for index_old, param_old in enumerate(model_2.parameters()):
      for index_new, param_new in enumerate(meta_model):
        param_old = param_new
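
The last two loops in this snippet do not change either model: `param_old = param_new` only rebinds the local loop variable `param_old`, and the nested loops would pair every parameter with every averaged tensor anyway. A plain-Python sketch of the difference between rebinding a name and mutating an object in place, with lists standing in for parameter tensors (an assumption for illustration; with tensors, the in-place counterpart would be `param_old.copy_(param_new)` under `torch.no_grad()`):

```python
# Lists stand in for parameter tensors; `params` plays the role of
# model.parameters() and `averaged` the role of meta_model.
params = [[1.0, 1.0], [2.0, 2.0]]
averaged = [[5.0, 5.0], [6.0, 6.0]]

# Rebinding the loop variable only changes what the local name refers to;
# the objects stored in `params` are left untouched.
for p_old, p_new in zip(params, averaged):
    p_old = p_new
print(params)  # [[1.0, 1.0], [2.0, 2.0]]

# Mutating each object in place (slice assignment here, Tensor.copy_ for
# tensors) is what actually changes the stored values.
for p_old, p_new in zip(params, averaged):
    p_old[:] = p_new
print(params)  # [[5.0, 5.0], [6.0, 6.0]]
```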

However, when I compare the two trained models at the end of training using this function:

def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            pass
        else:
            models_differ += 1
            if (key_item_1[0] == key_item_2[0]):
                print('Mismtach found at', key_item_1[0])
            else:
                raise Exception
    if models_differ == 0:
        print('Models match perfectly!')

I see that the two models differ in every layer, whereas only the last 4 layers (those under “model.2”) should differ. How can I fix this?
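
A minimal sketch of one possible fix, under stated assumptions rather than as the method used in this thread: match parameters by name, skip names starting with `model.2.`, and write the average back into both models in place with `Tensor.copy_` under `torch.no_grad()`. The helper `average_shared_parameters` and the toy `TwoStageNet` module (whose children are registered as `model.0`, `model.1`, `model.2`, mimicking the UNet key names in this thread) are hypothetical.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # the averaging step should not be recorded by autograd
def average_shared_parameters(model_1, model_2, private_prefix="model.2."):
    """Hypothetical helper: average every parameter whose name does not start
    with `private_prefix` and write the result back into both models in place."""
    params_2 = dict(model_2.named_parameters())
    for name, p1 in model_1.named_parameters():
        if name.startswith(private_prefix):
            continue  # keep the model-specific last block untouched
        p2 = params_2[name]  # match by name rather than by position
        avg = (p1 + p2) / 2
        p1.copy_(avg)  # in-place update of the tensor the model actually uses
        p2.copy_(avg)


class TwoStageNet(nn.Module):
    """Hypothetical stand-in for the UNet: children named model.0, model.1, model.2."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2))


torch.manual_seed(0)
net_1, net_2 = TwoStageNet(), TwoStageNet()
average_shared_parameters(net_1, net_2)
for (name, a), b in zip(net_1.named_parameters(), net_2.parameters()):
    print(name, "equal" if torch.equal(a, b) else "different")
```

In the training loop above, the whole body of `if epoch % 10 == 0:` could then be replaced by a single `average_shared_parameters(model_1, model_2)` call. Using `startswith("model.2.")` instead of a substring test keeps the match restricted to the top-level last block.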

I would be careful about expecting all state dict items to be equal unless, e.g., you are also sure that the running statistics of the normalization layers are combined across the models (or that both models saw identical data in an identical order during training). Could you check which keys are mismatching, and whether they correspond to tensors other than the weights?
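
Following that suggestion, a sketch that splits the mismatching `state_dict` keys into learnable parameters and buffers (such as normalization running statistics), looking up the second model's tensors by key. `mismatched_keys` is a hypothetical helper, and the small `nn.Sequential` models are assumed stand-ins for the two UNets.

```python
import torch
import torch.nn as nn


def mismatched_keys(model_1, model_2):
    """Hypothetical helper: split the state_dict keys whose tensors differ into
    parameter keys and buffer keys. Assumes both models share one architecture,
    so their state_dict keys are identical."""
    param_names = {name for name, _ in model_1.named_parameters()}
    sd_1, sd_2 = model_1.state_dict(), model_2.state_dict()
    differing = [k for k in sd_1 if not torch.equal(sd_1[k], sd_2[k])]
    params = [k for k in differing if k in param_names]
    buffers = [k for k in differing if k not in param_names]
    return params, buffers


torch.manual_seed(0)
net_1 = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
net_2 = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
net_2.load_state_dict(net_1.state_dict())  # start from identical values
net_1.train()
net_1(torch.randn(8, 3))  # a train-mode forward pass updates only net_1's running stats
params, buffers = mismatched_keys(net_1, net_2)
print("parameters:", params)
print("buffers:", buffers)
```

Here only the BatchNorm buffers differ even though no parameter was changed, which is the situation the reply warns about.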

Thanks @eqy! I believe my code was combining the running stats for the norm layers too.
I get a mismatch on every key:

Mismtach found at model.0.conv.unit0.conv.weight
Mismtach found at model.0.conv.unit0.conv.bias
Mismtach found at model.0.conv.unit0.adn.N.weight
Mismtach found at model.0.conv.unit0.adn.N.bias
Mismtach found at model.0.conv.unit0.adn.N.running_mean
Mismtach found at model.0.conv.unit0.adn.N.running_var
Mismtach found at model.0.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.0.conv.unit0.adn.A.weight
Mismtach found at model.0.conv.unit1.conv.weight
Mismtach found at model.0.conv.unit1.conv.bias
Mismtach found at model.0.conv.unit1.adn.N.weight
Mismtach found at model.0.conv.unit1.adn.N.bias
Mismtach found at model.0.conv.unit1.adn.N.running_mean
Mismtach found at model.0.conv.unit1.adn.N.running_var
Mismtach found at model.0.conv.unit1.adn.N.num_batches_tracked
Mismtach found at model.0.conv.unit1.adn.A.weight
Mismtach found at model.0.residual.weight
Mismtach found at model.0.residual.bias
Mismtach found at model.1.submodule.0.conv.unit0.conv.weight
Mismtach found at model.1.submodule.0.conv.unit0.conv.bias
Mismtach found at model.1.submodule.0.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.0.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.0.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.0.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.0.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.0.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.0.conv.unit1.conv.weight
Mismtach found at model.1.submodule.0.conv.unit1.conv.bias
Mismtach found at model.1.submodule.0.conv.unit1.adn.N.weight
Mismtach found at model.1.submodule.0.conv.unit1.adn.N.bias
Mismtach found at model.1.submodule.0.conv.unit1.adn.N.running_mean
Mismtach found at model.1.submodule.0.conv.unit1.adn.N.running_var
Mismtach found at model.1.submodule.0.conv.unit1.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.0.conv.unit1.adn.A.weight
Mismtach found at model.1.submodule.0.residual.weight
Mismtach found at model.1.submodule.0.residual.bias
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.conv.weight
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.conv.bias
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.0.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.conv.weight
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.conv.bias
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.0.conv.unit1.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.0.residual.weight
Mismtach found at model.1.submodule.1.submodule.0.residual.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.residual.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.0.residual.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.residual.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.1.submodule.residual.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.2.0.conv.weight
Mismtach found at model.1.submodule.1.submodule.2.0.conv.bias
Mismtach found at model.1.submodule.1.submodule.2.0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.2.0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.2.0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.2.0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.2.0.adn.A.weight
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.conv.weight
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.conv.bias
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight
Mismtach found at model.1.submodule.2.0.conv.weight
Mismtach found at model.1.submodule.2.0.conv.bias
Mismtach found at model.1.submodule.2.0.adn.N.weight
Mismtach found at model.1.submodule.2.0.adn.N.bias
Mismtach found at model.1.submodule.2.0.adn.N.running_mean
Mismtach found at model.1.submodule.2.0.adn.N.running_var
Mismtach found at model.1.submodule.2.0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.2.0.adn.A.weight
Mismtach found at model.1.submodule.2.1.conv.unit0.conv.weight
Mismtach found at model.1.submodule.2.1.conv.unit0.conv.bias
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.N.weight
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.N.bias
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.N.running_mean
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.N.running_var
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked
Mismtach found at model.1.submodule.2.1.conv.unit0.adn.A.weight
Mismtach found at model.2.0.conv.weight
Mismtach found at model.2.0.conv.bias
Mismtach found at model.2.0.adn.N.weight
Mismtach found at model.2.0.adn.N.bias
Mismtach found at model.2.0.adn.N.running_mean
Mismtach found at model.2.0.adn.N.running_var
Mismtach found at model.2.0.adn.N.num_batches_tracked
Mismtach found at model.2.0.adn.A.weight
Mismtach found at model.2.1.conv.unit0.conv.weight
Mismtach found at model.2.1.conv.unit0.conv.bias
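
In the listing above, the `adn.N` entries include `running_mean`, `running_var` and `num_batches_tracked`. In PyTorch normalization layers these are registered buffers, not parameters, so a loop over `named_parameters()` never sees them. If the running statistics should also be shared, as the earlier reply suggests, a sketch could average the floating-point buffers as well. This is an assumption-laden illustration: `average_shared_buffers` is hypothetical, `nn.BatchNorm1d` is an assumed stand-in for the UNet's norm layers, and averaging `running_var` is only an approximation of the variance over both datasets.

```python
import torch
import torch.nn as nn

# named_parameters() yields only learnable tensors; running statistics are buffers.
bn = nn.BatchNorm1d(4)
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])  # ['running_mean', 'running_var', 'num_batches_tracked']


@torch.no_grad()
def average_shared_buffers(model_1, model_2, private_prefix="model.2."):
    """Hypothetical helper: average floating-point buffers (e.g. running_mean,
    running_var) outside `private_prefix` in place. Integer buffers such as
    num_batches_tracked are skipped, so they can still differ between models."""
    buffers_2 = dict(model_2.named_buffers())
    for name, b1 in model_1.named_buffers():
        if name.startswith(private_prefix) or not b1.is_floating_point():
            continue
        b2 = buffers_2[name]
        avg = (b1 + b2) / 2
        b1.copy_(avg)
        b2.copy_(avg)


torch.manual_seed(0)
net_1, net_2 = nn.Sequential(nn.BatchNorm1d(4)), nn.Sequential(nn.BatchNorm1d(4))
net_1(torch.randn(8, 4))  # different batches give different running statistics
net_2(torch.randn(8, 4))
average_shared_buffers(net_1, net_2)
print(torch.equal(net_1[0].running_mean, net_2[0].running_mean))  # True
```

Because `num_batches_tracked` is skipped, it will still show up as a mismatch whenever the two loaders yield different numbers of batches.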

@ptrblck I could really use your advice on this. Thanks!