How to optimise a mask on a fully frozen network?

I am trying to reimplement mask training as described in the paper "Are Neural Nets Modular? Inspecting…" (Csordas et al., 2021), and I am having trouble understanding why the logits are not changing when the optimiser steps.

Preface: the paper suggests fully freezing the network before mask training, sampling the mask with a Gumbel-sigmoid, binarising it with a straight-through estimator, and backpropagating the target loss to the logits.
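A minimal sketch of the Gumbel-sigmoid with a straight-through estimator, assuming the same formula as the code further down, sigmoid((l - log(log U1 / log U2)) / tau). One deliberate difference from that code: the uniform noise is drawn independently for every mask entry with `torch.rand_like`, rather than as one scalar per step; per-entry sampling is an assumption about the intended behaviour. `gumbel_sigmoid_st` is a hypothetical helper name.

```python
import torch

def gumbel_sigmoid_st(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Straight-through Gumbel-sigmoid: hard 0/1 values forward, soft gradient backward."""
    # Independent uniform noise per mask entry (an assumption, see above),
    # clamped away from 0 and 1 so that the nested logarithms stay finite.
    u1 = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
    u2 = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
    # Soft sample: sigmoid((l - log(log U1 / log U2)) / tau)
    soft = torch.sigmoid((logits - torch.log(torch.log(u1) / torch.log(u2))) / tau)
    hard = (soft > 0.5).float()
    # The detached term shifts the forward value to `hard`
    # while the gradient is that of `soft`.
    return (hard - soft).detach() + soft

logits = torch.full((3, 4), 0.9, requires_grad=True)
mask = gumbel_sigmoid_st(logits, tau=1.0)
mask.sum().backward()
print(mask)         # entries are 0. or 1.
print(logits.grad)  # not None: the gradient reached the logits
```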

Issue: the logits don't change at all; they stay at their initial value of 0.9 throughout training (the loss curves are essentially flat). What should I be doing to actually train the mask?
(Sidebar: if I uncomment the regularisation term in the loss function, the loss tends linearly towards -inf, as a result of the logits never changing.)
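For the regulariser to have any effect it must stay in the autograd graph, which the commented-out line below prevents with `.clone().detach()`. A sketch, assuming (as that line does) a penalty of alpha times the sum of all mask logits; `mask_regulariser` is a hypothetical helper name. Its gradient with respect to every logit is simply alpha, so on its own it pushes all logits down at a constant rate; it only makes sense added to a task loss whose gradient also reaches the logits.

```python
import torch

def mask_regulariser(logits, alpha):
    """alpha * (sum of all mask logits), kept in the graph (no .detach())."""
    return alpha * torch.cat([l.reshape(-1) for l in logits]).sum()

logits = [torch.nn.Parameter(torch.full((2, 3), 0.9)),
          torch.nn.Parameter(torch.full((4,), 0.9))]
reg = mask_regulariser(logits, alpha=0.5)
reg.backward()
print(logits[0].grad)  # every entry equals alpha = 0.5
```

In a training step this would be used as `loss = criterion(otp_pred, otp_) + mask_regulariser(logits, alpha)`.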

Code below (it probably has a bunch of errors and redundancies, so bear with me, thanks):

import torch

'''Load the pretrained network and freeze it'''
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir + network_name, checkpoint=checkpoint, use_optimiser=False)  # don't use the network's own optimiser
handler.refreeze_weights()

'''Initialise logits & define loss and optimiser'''
logits = []
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):
        logits.append(torch.nn.Parameter(data=torch.full_like(layer.weight.clone(), 0.9), requires_grad=True))

criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(logits, lr=0.01)

'''Initialise hyper-parameters'''
NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.0001/128  # regularisation parameter, NB: check for value in paper

'''Mask Training'''

for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')
    
    '''Reload frozen weights'''
    handler.network.load_save()
    handler.refreeze_weights()
    
    '''Sampling and generating masks.'''
    U1 = torch.rand(1, requires_grad=True).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=True).to(handler.network.device)
    
    '''Gumbel Sigmoid & Straight through'''
    samples = []
    for layer in logits:       
        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []        
    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)
        
    binaries = []
    for idx, layer in enumerate(binaries_stop):
        binaries.append(layer + samples[idx])
   
    '''Apply mask to weights'''
    idx = 0
    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight = torch.nn.parameter.Parameter(layer.weight * binaries[idx])
            idx += 1

    '''Inference with masked network and backpropagation.'''
    batch = next(iterator_train)
    with torch.no_grad():
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch]) # input
        otp = torch.stack([b[2] for b in batch]) # output
        ops = torch.stack([b[3] for b in batch]) # arithmetic operation
        # Convert batch data to one-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
        inp_ = inp.to(handler.network.device)
        otp_ = otp_.to(handler.network.device)

        otp_pred = handler.network(inp_) # fwd pass

        #all_logits = alpha*torch.cat([layer.clone().detach().view(-1) for layer in logits]).to(handler.network.device) # regulariser - can ignore for now
    optimiser.zero_grad()
    loss = criterion(otp_pred, otp_).to(handler.network.device)  # + torch.sum(all_logits)
    loss.requires_grad_(requires_grad=True)
    loss.backward()
    optimiser.step()

    if e % 100 == 0:
        print('Saving Mask...')
        torch.save(logits, 'masks/trained_logits_add_mask_.pt') 

EDIT: Testing this on a large FNN (5 layers, size 2000).

I fixed the problem, but I'm going to leave this here since there aren't many relevant discussions on this topic, so it may be useful.

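The first issue lies in subtle disconnects from the computational graph, for example in the line that uses `binaries[idx]`. Fetching the mask from a list is not the problem (plain Python indexing does not touch autograd); the disconnect comes from wrapping `layer.weight * binaries[idx]` in a new `torch.nn.parameter.Parameter`. A `Parameter` built from a tensor is a fresh leaf with no history, so no gradient can flow from it back to `binaries[idx]` and hence to the logits.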
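A small self-contained check of this, using a random 3×4 tensor as a stand-in for a frozen `Linear` weight and a plain sigmoid of the logits as a stand-in for the binarised mask (both simplifications for illustration only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
frozen_w = torch.randn(3, 4)                               # stand-in for a frozen Linear weight
mask_logits = torch.full((3, 4), 0.9, requires_grad=True)
mask = torch.sigmoid(mask_logits)                          # stand-in for the binarised mask
x = torch.randn(2, 4)

# Re-wrapping the masked weight in a Parameter creates a new leaf with no history.
rewrapped = torch.nn.Parameter(frozen_w * mask)
print(rewrapped.grad_fn, rewrapped.is_leaf)                # None True
F.linear(x, rewrapped).sum().backward()
grad_after_rewrap = mask_logits.grad
print(grad_after_rewrap)                                   # None: nothing reached the logits

# Using the product directly keeps the path back to the logits.
masked = frozen_w * mask
F.linear(x, masked).sum().backward()
print(mask_logits.grad is not None)                        # True
```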

The second issue: since we wanted to train the mask on a fully frozen network, whose parameters should not track gradients, we can't just write the mask into the frozen network and forward the input as we did during pre-training. In the code above the forward pass also runs inside `torch.no_grad()`, so no graph is recorded at all; the loss has no history, and `loss.requires_grad_(True)` only silences the error that `backward()` would otherwise raise, without reconnecting anything. Note that the frozen weights themselves are not what blocks the gradient: `layer.weight * mask` still requires grad whenever the mask does, as long as that product is used directly in the forward pass.
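A minimal reproduction of both points, assuming a single frozen `torch.nn.Linear` layer, random inputs and targets, and a plain sigmoid of the logits in place of the Gumbel-sigmoid mask (illustrative simplifications, not the network from the question):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
layer.requires_grad_(False)                                  # fully frozen "pretrained" layer
mask_logits = torch.full_like(layer.weight, 0.9, requires_grad=True)
x = torch.randn(2, 4)
target = torch.tensor([0, 2])

# 1) Forward under no_grad: nothing is recorded, so the loss has no history.
with torch.no_grad():
    out = F.linear(x, layer.weight * torch.sigmoid(mask_logits), layer.bias)
loss = F.cross_entropy(out, target)
print(loss.grad_fn)                                          # None
loss.requires_grad_(True)  # silences the error from backward() ...
loss.backward()
grad_after_no_grad = mask_logits.grad
print(grad_after_no_grad)                                    # None: ... but the logits get nothing

# 2) Same forward with autograd enabled: the frozen weight acts as a constant
#    and does not block the gradient to the mask logits.
out = F.linear(x, layer.weight * torch.sigmoid(mask_logits), layer.bias)
F.cross_entropy(out, target).backward()
print(mask_logits.grad is not None, layer.weight.grad)       # True None
```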

So what to do instead? Rather than passing the batch through the frozen layer directly, we apply an identical layer functionally: the frozen weight (multiplied by the mask) and the frozen bias are passed to `F.linear`, so nothing is written back into the pretrained module, but the computation is still tracked by autograd. The code below shows this for the simple FNN the inspection method was intended for. Note that `layer.weight` is frozen, so it won't accumulate gradients, while each mask taken from the mask iterator leads back to a leaf (the logits), which does accumulate gradients on the backward pass; the masked weight is never re-wrapped as a `Parameter`, so the graph break from the first issue is avoided. Also, each mask must have exactly the same shape as the weight it multiplies; otherwise the weights won't line up with the bias, and you are not training a mask on an identical network.

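import torch
import torch.nn.functional as F
x = inp_                        # batch input, as in the question's loop
iterator_mask = iter(binaries)  # one straight-through mask per Linear layer
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):  # layer specific to FNN
        weight = layer.weight * next(iterator_mask)  # frozen weight times mask, stays in the graph
        x = F.linear(x, weight, layer.bias)
    else:
        x = layer(x)
otp_pred = x  # no torch.no_grad() around this forward pass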
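A self-contained sketch of the whole training loop with this approach, assuming a tiny randomly initialised `nn.Sequential` as a stand-in for the pretrained FNN and random data in place of the addition/multiplication batches; `sample_mask` and `masked_forward` are hypothetical helper names, and the per-entry noise is an assumption rather than something taken from the code above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained FNN: a small, fully frozen nn.Sequential.
net = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
net.requires_grad_(False)

# One logit tensor per Linear weight, with exactly the same shape as that weight.
linears = [m for m in net if isinstance(m, torch.nn.Linear)]
logits = [torch.nn.Parameter(torch.full_like(m.weight, 0.9)) for m in linears]
optimiser = torch.optim.Adam(logits, lr=0.01)
tau = 1.0

def sample_mask(l):
    """Straight-through Gumbel-sigmoid with independent noise per entry."""
    u1 = torch.rand_like(l).clamp(1e-6, 1 - 1e-6)
    u2 = torch.rand_like(l).clamp(1e-6, 1 - 1e-6)
    soft = torch.sigmoid((l - torch.log(torch.log(u1) / torch.log(u2))) / tau)
    return ((soft > 0.5).float() - soft).detach() + soft

def masked_forward(x, masks):
    """Apply the frozen layers functionally, multiplying each weight by its mask."""
    it = iter(masks)
    for layer in net:
        if isinstance(layer, torch.nn.Linear):
            x = F.linear(x, layer.weight * next(it), layer.bias)
        else:
            x = layer(x)
    return x

x = torch.randn(64, 8)
y = torch.randint(0, 4, (64,))
start = [l.detach().clone() for l in logits]

for step in range(50):
    masks = [sample_mask(l) for l in logits]
    loss = F.cross_entropy(masked_forward(x, masks), y)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

changed = any(not torch.equal(a, l.detach()) for a, l in zip(start, logits))
print(changed)  # True: the mask logits moved, the frozen weights did not
```

In the question's setting, `net` would correspond to `handler.network.layers[0]` and `x`, `y` to the one-hot batch from `handler.set_batched_digits`; the key points are that the forward pass is not wrapped in `torch.no_grad()` and that `loss.requires_grad_(True)` is no longer needed.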

Hopefully this is helpful to someone.

This is quite helpful for me. Surprisingly, there is very little information on this topic, even in this community. Hats off to you for explaining it in such detail.