I am trying to reimplement mask training as described in the paper Are Neural Nets Modular? Inspecting… (Csordás et al., 2021), and I am having trouble understanding why the logits are not changing under the optimiser.
Preface: The paper suggests fully freezing the network prior to mask training, sampling a Gumbel-sigmoid and applying a straight-through estimator to binarise the mask, and backpropagating the task loss to the logits.
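For reference, here is a minimal sketch of the sampling step as I understand it (the per-element uniform noise and this exact straight-through form are my reading of the paper, not a verbatim copy):

import torch

def sample_binary_mask(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    # Gumbel-sigmoid: add the difference of two Gumbel noise samples to the
    # logits, then squash with a temperature-scaled sigmoid.
    U1 = torch.rand_like(logits)
    U2 = torch.rand_like(logits)
    soft = torch.sigmoid((logits - torch.log(torch.log(U1) / torch.log(U2))) / tau)
    # Straight-through estimator: hard 0/1 values in the forward pass,
    # gradients of the soft relaxation in the backward pass.
    hard = (soft > 0.5).float()
    return hard + soft - soft.detach()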
Issue: The logits do not seem to change at all; they remain fixed at their initialised value of 0.9 throughout training (the loss curves are essentially flat). What should I be doing to actually train the mask?
(Sidebar: if I uncomment the regulariser term in the loss function, the loss tends linearly towards -inf, again with the logits never changing.)
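For context, the regulariser I have in mind is simply alpha times the sum of all mask logits, as in the commented-out line in my code below; since raw logits are unbounded below, such a term can on its own drive the loss towards -inf. A minimal sketch (my own helper, assuming that form of the penalty):

import torch

def regularised_loss(task_loss: torch.Tensor, logits: list, alpha: float) -> torch.Tensor:
    # Sparsity penalty: alpha * sum over every mask logit. Deliberately not
    # detached -- detaching would stop the penalty from moving the logits.
    penalty = alpha * torch.cat([l.view(-1) for l in logits]).sum()
    return task_loss + penalty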
Code below (it probably has a bunch of errors and redundancies -- bear with me, thanks):
'''Initialise the handler, freeze the network, then set up the logits, loss and optimiser'''
# The handler must exist before its layers can be read below.
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir + network_name, checkpoint=checkpoint, use_optimiser=False)  # don't use the network's own optimiser
handler.refreeze_weights()
logits = []
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):
        # One trainable logit per weight, initialised to 0.9
        logits.append(torch.nn.Parameter(torch.full_like(layer.weight, 0.9), requires_grad=True))
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(logits, lr=0.01)
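# Note: the optimiser is given only the mask logits; the frozen network
# weights are not in its parameter list, so only the mask should change.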
'''Initialise hyper-parameters'''
NUM_EPOCHS = 20000  # NB: check the number of training epochs in the paper
tau = 1  # temperature parameter, NB: check the value used in the paper
alpha = 0.0001 / 128  # regularisation parameter, NB: check the value used in the paper
'''Mask training'''
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')
    '''Reload the frozen weights (undo the previous epoch's masking)'''
    handler.network.load_save()
    handler.refreeze_weights()
    '''Gumbel-sigmoid sampling & straight-through estimator'''
    samples = []
    for layer in logits:
        # Fresh per-element uniform noise each epoch (one sample per mask
        # entry rather than a single shared scalar); the noise needs no grad.
        U1 = torch.rand_like(layer)
        U2 = torch.rand_like(layer)
        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))
    binaries = []
    for s in samples:
        # Hard 0/1 mask in the forward pass, soft-mask gradient in the backward pass
        binaries.append((s > 0.5).float() + s - s.detach())
    '''Apply mask to weights'''
    idx = 0
    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight = torch.nn.Parameter(layer.weight * binaries[idx])
            idx += 1
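    # NB: re-wrapping the masked product in a fresh Parameter creates a new
    # leaf tensor, so I suspect this step may cut the mask (and hence the
    # logits) out of the autograd graph -- flagging it as my prime suspect.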
    '''Inference with the masked network and backpropagation'''
    batch = next(iterator_train)
    with torch.no_grad():
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])  # input
        otp = torch.stack([b[2] for b in batch])  # output
        ops = torch.stack([b[3] for b in batch])  # arithmetic operation
        # Convert batch data to one-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
    inp_ = inp.to(handler.network.device)
    otp_ = otp_.to(handler.network.device)
    otp_pred = handler.network(inp_)  # forward pass (outside no_grad, so a graph is built)
    # all_logits = alpha * torch.cat([l.view(-1) for l in logits])  # regulariser -- can ignore for now
    optimiser.zero_grad()
    loss = criterion(otp_pred, otp_)  # + torch.sum(all_logits)
    # Without the next line backward() raises, because the loss does not require
    # grad -- presumably a symptom of the real problem rather than a fix for it.
    loss.requires_grad_(requires_grad=True)
    loss.backward()
    optimiser.step()
    if e % 100 == 0:
        print('Saving mask...')
        torch.save(logits, 'masks/trained_logits_add_mask_.pt')
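In case it is useful, here is a minimal gradient sanity check (my own debugging helper, not from the paper) that can be run immediately after loss.backward() to see whether anything reaches the logits:

for i, l in enumerate(logits):
    # A grad of None means the logit tensor is disconnected from the loss.
    if l.grad is None:
        print(f'logit tensor {i}: no gradient (disconnected from the loss?)')
    else:
        print(f'logit tensor {i}: |grad| sum = {l.grad.abs().sum().item():.3e}')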
EDIT: I am testing this on a large FNN (5 layers, width 2000).