Keeping the Computation Graph Connected for Interpolating Model Parameters

Hi, I am trying to implement the optimization procedure from the paper "Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs" by Garipov et al., and I am struggling to get the optimization to work. There is a lot of model parameter assignment involved, and I was hoping someone could help me with it.

The main idea is the following (in somewhat pseudocode). Suppose I have two trained models, model1 and model2, with parameters w_1 = model1.parameters and w_2 = model2.parameters. I can piecewise-linearly interpolate between these models through a middle point $p$ that has the same dimension as the parameters. The interpolated parameters, for $t \in [0, 1]$, are

if $t \le 1/2$ then $w_{new} = (1 - 2t)\, w_1 + 2t\, p$
if $t > 1/2$ then $w_{new} = (2 - 2t)\, p + (2t - 1)\, w_2$

This is just piecewise linear interpolation: at $t = 0$ we recover $w_1$, at $t = 1/2$ we are exactly at $p$, and at $t = 1$ we recover $w_2$. With $w_1, w_2$ fixed, I want to optimize $p$ so that the loss along this interpolation path is minimized. I hope this is clear.

My code is as follows:

Piecewise Interpolate

This takes the parameters of the two models (as torch.nn.Parameter), the middle point p, and a position alpha as input, and returns a tensor holding the interpolated values at alpha. I don't think this is where the computation gets detached.

def piecewise_interpolate(params1, params2, p, alpha):
    # First half of the path (alpha in [0, 0.5]): interpolate from params1 to the middle point p
    if alpha <= 0.5:
        output = (1 - 2 * alpha) * params1 + (2 * alpha) * p
    # Second half (alpha in (0.5, 1]): interpolate from p to params2
    else:
        output = (1 - 2 * (alpha - 0.5)) * p + (2 * (alpha - 0.5)) * params2
    return output
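
For reference, a quick sanity check with toy tensors (the shapes here are made up, not my real models) shows the kind of connectivity I expect from this function: the output should carry a grad_fn as long as p requires grad.

import torch

w1 = torch.randn(4, 3)                      # stand-in for a frozen parameter of model1
w2 = torch.randn(4, 3)                      # stand-in for a frozen parameter of model2
p = torch.nn.Parameter(torch.randn(4, 3))   # the middle point we want to optimize

out = piecewise_interpolate(w1, w2, p, 0.3)
print(out.grad_fn)        # a valid grad_fn (e.g. AddBackward0), so the graph is intact here
print(out.requires_grad)  # True, because p requires grad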

Interpolation Model

I believe the mistake might be here, in how I assign the parameters in __init__ or forward. I have tried many different ways, but none of them worked.

In __init__, I initialize a model whose weights, for each layer, are the average of the two models' parameters. This average is the trainable parameter; I do not want any of the weights of model1 or model2 to be trained.

In forward, for a given middle point p and a $t$ that indicates where along the interpolation we evaluate the model, I compute the weights at $t$ and put them into a new "interpolated model", whose forward function we then evaluate. I want to assign the parameters in such a way that, once we compute the loss with the interpolated model and call the optimizer, the computation graph reaches back to the parameter p and updates it, without actually updating the parameters of the interpolated model.

class InterpolationModel(torch.nn.Module):

  def __init__(self, device, model1, model2):
    super(InterpolationModel, self).__init__()
    self.new_model = SimpleMLP().to(device)
    self.model1 = model1
    self.model2 = model2
    self.model1_state_dict = model1.state_dict()
    self.model2_state_dict = model2.state_dict()


    # Initialize each layer of new_model with the average of the two models' parameters
    for layer_name in list(self.new_model.state_dict()):
      model1_layer_param = dict(self.model1.named_parameters())[layer_name]
      model2_layer_param = dict(self.model2.named_parameters())[layer_name]
      new_layer_param = torch.nn.Parameter((model1_layer_param + model2_layer_param) / 2)

      layer_name_found = False
      for name, param in self.new_model.named_parameters():
        if name == layer_name:
          weights = new_layer_param
          layer_name_found = True
          break
          
      if not layer_name_found:
        raise ValueError(f'Could not find any parameters named {layer_name} - fix the code!')

  

  def forward(self, t, x):
    # Build a fresh throwaway model and fill it with the interpolated weights at position t
    no_training_model = SimpleMLP().to(device)
    for layer_name in list(no_training_model.state_dict()):
      param1 = dict(self.model1.named_parameters())[layer_name]
      param2 = dict(self.model2.named_parameters())[layer_name]
      weights = dict(self.new_model.named_parameters())[layer_name]

      new_param = piecewise_interpolate(param1, param2, weights, t)
      interpolated_model_layer_parameter = dict(no_training_model.named_parameters())[layer_name]
      interpolated_model_layer_parameter.data = new_param

    output = no_training_model(x)
    return output
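
A sanity check along the following lines should show whether the averaging in __init__ actually ends up inside new_model ('fc1.weight' is just a placeholder for whatever layer names SimpleMLP actually has, and interpolation_model is an instance constructed as in the training section below):

# Sanity-check sketch: compare one layer of new_model against the average of the two models.
with torch.no_grad():
    avg = (dict(model1.named_parameters())['fc1.weight']
           + dict(model2.named_parameters())['fc1.weight']) / 2
    cur = dict(interpolation_model.new_model.named_parameters())['fc1.weight']
    print(torch.allclose(avg, cur))   # False would mean the assignment in __init__ did not stick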

Training Procedure

Below is how I train the model. I just pick a random t (as done in the paper), evaluate the loss at that t, and update the parameters accordingly. Unfortunately, what I find is that param.grad is None in the debugging loop below. All of those parameters are leaf tensors, so that is not the problem.

# Training function
# We pass in the test_loader of the ORIGINAL models, because we want to find the minimum-loss path with respect to the TEST loss
def train_interpolation_model(model, test_loader, criterion, optimizer, epochs=1):
    counter = 0
    model.train()
    for epoch in range(epochs):
        # This gives an approximate average loss, to check that the model is training. It is printed at the end of every epoch.
        epoch_losses = []
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            t = torch.rand(1).to(device)
            output = model(t, data)
            loss = criterion(output, target)
            loss.backward()

            # For debugging purposes - after loss.backward() the params should not have None grads.
            # print(loss.grad_fn)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # print(f'Layer Name: {name}')
                    # print(f'Parameter Grad: {param.grad}')
                    # print(f'Parameter Is Leaf (must be True - grad is accumulated on leafs): {param.is_leaf}')
                    # print("============")
                    break

            optimizer.step()
            epoch_losses.append(loss.item())  # store a plain float instead of the graph-carrying tensor

            # if counter % 100 == 0:
            #   print(f'Loss at step {counter}: {loss}')
            # counter += 1

        print(f'Epoch {epoch + 1} / {epochs} done! Epoch Average Loss: {sum(epoch_losses) / len(epoch_losses)}')

# CALLING THE TRAINING
interpolation_model = InterpolationModel(device, model1, model2).to(device)
data_loader = test_loader
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(interpolation_model.parameters(), lr=0.01)
epochs = 10

train_interpolation_model(model=interpolation_model,
                          test_loader=data_loader,
                          criterion=criterion,
                          optimizer=optimizer,
                          epochs=epochs)

Do the parameters get detached at some point? And if so, where? I have tried a bunch of different ways of assigning them with no luck; I have been spending the past 3-4 days on this and could not manage to get it working. I would appreciate any help.

What methods have you tried? Generally, the way to debug these issues is to bisect the sequence of ops from the parameter you need gradients for to the output, and observe at which point a tensor's grad_fn becomes None. If a tensor has a valid grad_fn, the operation that produced it is part of the connected graph.
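
For example, a rough sketch of that kind of bisection, using the names from your post (the layer name 'fc1.weight' and the input shape are placeholder assumptions, substitute whatever SimpleMLP actually uses):

# Rough debugging sketch (not a fix): check step by step whether the output is still
# connected to the middle-point parameter p. 'fc1.weight' and the input shape (8, 784)
# are placeholders - adjust them to your actual setup.
x = torch.randn(8, 784).to(device)
t = torch.rand(1).to(device)

p  = dict(interpolation_model.new_model.named_parameters())['fc1.weight']
w1 = dict(interpolation_model.model1.named_parameters())['fc1.weight']
w2 = dict(interpolation_model.model2.named_parameters())['fc1.weight']

# Step 1: the interpolation itself is built from p, so it should carry a grad_fn
interp = piecewise_interpolate(w1, w2, p, t.item())
print(interp.grad_fn)    # expected: a valid grad_fn, not None

# Step 2: ask autograd directly whether the model output depends on p
out = interpolation_model(t, x)
grad_p = torch.autograd.grad(out.sum(), p, allow_unused=True)[0]
print(grad_p)            # None here means the graph from p to the output was cut somewhere in forward()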