Creating trainable masks for model parameters

Hello everyone,

I’m working with GPT-2 and I need to create a trainable mask for each weight layer. In other words, for each weight layer W, there is a matrix M of the same shape such that during the forward pass, a thresholding function is applied to M in order to create a mask, which is then applied to W. During the backward pass, the matrix M would need to be updated, in order to better sample the appropriate masks. Also, the backward function for the thresholding function would need to be a straight-through estimator, so that the gradients aren’t 0.
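To make the idea concrete, here is a minimal sketch of the forward/backward behaviour I have in mind, on plain tensors rather than GPT-2. It uses the detach trick as one possible way to get a straight-through estimator; the shapes and the names `W`, `M`, `x` are placeholders for illustration only.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 4)                     # stand-in for one weight layer
M = torch.rand(3, 4, requires_grad=True)  # trainable scores, same shape as W

# Thresholding on its own has zero gradient almost everywhere.
hard_mask = (M > 0.5).float()

# Straight-through trick: the forward value equals hard_mask,
# while the backward pass treats the mask as the identity in M.
mask = hard_mask + (M - M.detach())

x = torch.randn(2, 4)
loss = (x @ (W * mask).t()).sum()
loss.backward()

print(M.grad)  # non-None: gradients reach M through the mask
```

In this toy case the gradient on `M` should equal `W` scaled column-wise by the summed inputs, which is what an identity backward through the threshold would give.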

I tried this approach but self.pre_mask (the matrix M) doesn’t get updated during the backward pass:

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

class ThresholdFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # 0.5 is the threshold
        return (input > 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Treats the threshold function like an identity function
        return grad_output

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = ThresholdFunction.apply(x)
        return x

class Masked(nn.Module):

    def __init__(self, orig_layer):
        super().__init__()
        # the matrix W
        self.orig_layer = orig_layer
        # the matrix M
        self.pre_mask = torch.rand(self.orig_layer.weight.size(), requires_grad=True)
        self.masking_layer = StraightThroughEstimator()

    def forward(self, *x):

        mask = nn.Parameter(self.masking_layer.forward(self.pre_mask))
        masked_layer = self.orig_layer
        masked_layer.weight = nn.Parameter(self.orig_layer.weight * mask)
        output = masked_layer(*x)

        return output

class GPT2_with_mask(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        # Example of modifying a GPT-2 layer so that it has a trainable mask
        self.model.transformer.h[0].attn.c_attn = Masked(self.model.transformer.h[0].attn.c_attn)

I’m also aware that using hooks might be another approach, but I’m not sure how to best proceed.
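One alternative to hooks that might fit here is `torch.nn.utils.parametrize`, available in recent PyTorch releases. The sketch below is only an assumption about how it could be wired up: `HardThreshold` and `TrainableMask` are hypothetical names, and a small `nn.Linear` stands in for a GPT-2 layer so the example runs without downloading a model.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class HardThreshold(torch.autograd.Function):
    """Binary threshold in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, scores):
        return (scores > 0.5).to(scores.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class TrainableMask(nn.Module):
    """Hypothetical parametrization: returns weight * threshold(pre_mask)."""

    def __init__(self, weight_shape):
        super().__init__()
        # Registered as a parameter so optimizers and .parameters() see it.
        self.pre_mask = nn.Parameter(torch.rand(weight_shape))

    def forward(self, weight):
        return weight * HardThreshold.apply(self.pre_mask)


torch.manual_seed(0)
layer = nn.Linear(4, 3)  # stand-in for a GPT-2 layer with a `weight` attribute
parametrize.register_parametrization(layer, "weight", TrainableMask(layer.weight.shape))

out = layer(torch.randn(2, 4))
out.sum().backward()

pre_mask = layer.parametrizations.weight[0].pre_mask
print(pre_mask.grad is not None)
```

The masked weight is recomputed from the stored original weight and `pre_mask` each time `layer.weight` is accessed, so nothing is re-wrapped in a new `nn.Parameter` inside `forward`. Whether this works unchanged on `self.model.transformer.h[0].attn.c_attn` is something I have not verified.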

If it helps, this is the paper that I’m trying to implement:

Any input or suggestions would be really appreciated. Thank you.