Optimizing a mask instead of weights

Hi,

I am trying to optimize a mask applied to the weights of a pre-trained model, instead of the weights themselves, in order to reproduce this paper: https://arxiv.org/pdf/2405.10989

For now, I am using a simple random mask for each weight. My code is the following:

import torch
from tqdm import tqdm

class MaskedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.masks = torch.nn.ParameterDict()

        # Create a learnable mask per weight
        for name, param in model.named_parameters():
            self.masks[name.replace(".", " ")] = torch.nn.Parameter(torch.rand(param.shape, requires_grad=True, dtype=param.dtype, device=self.model.device))

        with torch.no_grad():
            self.original_params = {name: p.clone() for name, p in self.model.named_parameters()}

    def forward(self, *args, **kwargs):
        # Apply mask
        for name, param in self.model.named_parameters():
            mask_name = name.replace(".", " ")
            if mask_name in self.masks:
                param.mul_(self.masks[mask_name])


        return self.model(*args, **kwargs)

        
    def restore_original_model(self):
        # Restore original weights
        with torch.no_grad():
            for name, p in self.model.named_parameters():
                p.data.copy_(self.original_params[name])


def mask_neuron_activation_algorithm(data_loader, model, device, num_steps=1, lr=1e-4):
    for p in model.parameters():
        p.requires_grad = False
    
    masked_model = MaskedModel(model)
    optimizer = torch.optim.SGD(masked_model.masks.parameters(), lr=lr)

    for step in range(num_steps):

        for batch in tqdm(data_loader, desc="Finding privacy neurons"):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = masked_model(input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
            print(loss)
            loss.backward()
            optimizer.step()
            masked_model.restore_original_model()

I get this error on the second training batch:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I do not understand which part of the graph is reused from the first iteration instead of being rebuilt during the second forward pass. The only similar problem I found is in the thread "Optimizing a mask for weights instead of weights themselves", where the answer is to use param.mul_ (which I already do).

I also tried passing masked_model.model.parameters() to the optimizer, but I get the same error.

Does anyone understand what is going on?

Hi!

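I think the issue is that the parameters of masked_model.model keep their grad_fn even after the call to restore_original_model, because you only restore the data of these tensors instead of restoring the tensors themselves. On the next forward pass, param.mul_ then chains onto the graph from the previous iteration, whose saved tensors were already freed by the first backward(), hence the error.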
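To make this concrete, here is a minimal reproduction on a single tensor. The names weight, mask and original are hypothetical stand-ins (one frozen model weight, its learnable mask, and the saved copy), not taken from the code above. The sketch checks that restoring through .data keeps the grad_fn created by the in-place mul_, and that the second masked forward plus backward() then fails with the same "backward through the graph a second time" error.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for one frozen model weight (requires_grad=False).
weight = torch.nn.Parameter(torch.ones(3), requires_grad=False)
original = weight.detach().clone()
mask = torch.nn.Parameter(torch.rand(3))  # learnable mask

# First iteration: in-place masking gives the frozen weight a grad_fn.
weight.mul_(mask)
weight.sum().backward()  # frees the tensors saved by this graph

# "Restore" through .data: the values come back, but the grad_fn stays.
with torch.no_grad():
    weight.data.copy_(original)
grad_fn_survived = weight.grad_fn is not None

# Second iteration: mul_ chains onto the old graph, whose buffers are freed.
weight.mul_(mask)
try:
    weight.sum().backward()
    second_backward_failed = False
except RuntimeError as err:
    second_backward_failed = "backward through the graph a second time" in str(err)

print(grad_fn_survived, second_backward_failed)
```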

I would try to:

  1. freeze the parameters of masked_model.model (set all their requires_grad to False, since you don’t want to optimize them anyway).
  2. change restore_original_model so that it restores the parameters entirely (not just the .data field).
  3. stop modifying the model's parameters in place (the param.mul_(...) call) and thus get rid of restore_original_model entirely. This should also improve performance and make the code much less bug-prone. To do that, you could use torch.func.functional_call, which lets you call your model with the masked params instead of its stored params.
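Option 2 could look roughly like the sketch below: it keeps the in-place masking from the question, but after each step it replaces every parameter with a brand-new frozen nn.Parameter, so no grad_fn is carried over to the next iteration. The toy nn.Linear model, the restore_parameters helper and the squared-output loss are hypothetical, chosen only to keep the example self-contained.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model standing in for the pre-trained network.
model = torch.nn.Linear(4, 1)
for p in model.parameters():
    p.requires_grad_(False)

originals = {n: p.detach().clone() for n, p in model.named_parameters()}
masks = {n: torch.rand_like(p, requires_grad=True) for n, p in model.named_parameters()}
optimizer = torch.optim.SGD(masks.values(), lr=1e-2)


def restore_parameters(model, originals):
    """Hypothetical helper: swap each parameter for a fresh frozen leaf."""
    for name, value in originals.items():
        module_name, _, param_name = name.rpartition(".")
        module = model.get_submodule(module_name)  # "" returns the model itself
        setattr(module, param_name, torch.nn.Parameter(value.clone(), requires_grad=False))


x = torch.randn(8, 4)
for step in range(3):
    for name, param in model.named_parameters():
        param.mul_(masks[name])  # in-place masking, as in the question
    loss = model(x).pow(2).mean()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    restore_parameters(model, originals)  # new Parameter objects, no old graph
```

Note that replacing the Parameter objects would silently untie weights that a model shares between modules (for example tied input and output embeddings), which is one more reason to prefer option 3.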

I recommend doing 1 and 3.
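For comparison, here is a minimal sketch of options 1 and 3 combined, using torch.func.functional_call (available in PyTorch 2.x). The small nn.Sequential model, the random data and the MSE loss are hypothetical placeholders for the pre-trained model and its language-modeling loss.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical toy model standing in for the pre-trained network.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))

# Option 1: freeze the model's own parameters.
for p in model.parameters():
    p.requires_grad_(False)

# One learnable mask per parameter tensor, kept outside the model.
masks = {n: torch.rand_like(p, requires_grad=True) for n, p in model.named_parameters()}
originals = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer = torch.optim.SGD(masks.values(), lr=1e-2)

x = torch.randn(16, 4)
y = torch.randn(16, 1)

losses = []
for step in range(3):
    # Option 3: build the masked weights out of place, so each step gets a fresh graph.
    masked_params = {n: masks[n] * p for n, p in model.named_parameters()}
    out = functional_call(model, masked_params, (x,))
    loss = torch.nn.functional.mse_loss(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Because the masked weights are recomputed out of place on every step and the model's stored parameters are never modified, no restore step is needed.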

If this doesn’t solve your issue, please provide a minimal runnable example (with the imports, model, dataloader, etc) so that I can run your function and see what’s wrong.


Thank you very much for the suggestion! It works with the following code:

import torch
from tqdm import tqdm

def mask_neuron_activation_algorithm(data_loader, model, device, num_steps=1, lr=1e-4, temperature=0.025, gamma=-0.1, zeta=1.1, eta=5):

    original_parameters = {}
    mask_parameters = {}

    # Create a learnable mask per weight
    for name, param in model.named_parameters():
        mask_parameters[name] = torch.rand(param.shape, requires_grad=True, dtype=param.dtype, device=device)
        original_parameters[name] = param.data.clone().detach().to(device)

    optimizer = torch.optim.SGD(mask_parameters.values(), lr=lr)

    for step in range(num_steps):

        for batch in tqdm(data_loader, desc="Finding privacy neurons"):

            batch_params = {}
            for name, param in model.named_parameters():
                batch_params[name] = mask_parameters[name]*original_parameters[name]

            batch = {k: batch[k].to(device) for k in batch}
            batch["labels"] = batch["input_ids"]


            outputs = torch.func.functional_call(model, batch_params, kwargs=batch)
            loss = outputs.loss
            print(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            del batch_params
            torch.cuda.empty_cache()