Hello everyone,
I’m working with GPT-2 and I need to create a trainable mask for each weight layer. In other words, for each weight matrix W there is a matrix M of the same shape such that, during the forward pass, a thresholding function is applied to M to produce a binary mask, which is then applied to W. During the backward pass, M needs to be updated so that better masks are sampled over time. The backward function of the thresholding operation needs to be a straight-through estimator, because the true gradient of a hard threshold is zero almost everywhere.
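To make the intended behaviour concrete, here is a tiny standalone sketch (toy tensors only, not my actual code) of the gradient flow I’m expecting; it uses the detach trick as an equivalent way of writing the straight-through estimator:

    import torch

    M = torch.rand(3, 3, requires_grad=True)   # the trainable pre-mask
    W = torch.randn(3, 3)                      # a frozen weight matrix

    hard = (M > 0.5).float()                   # hard threshold; its true gradient is 0 almost everywhere
    mask = hard + (M - M.detach())             # forward value = hard, backward behaves like the identity

    loss = (W * mask).sum()
    loss.backward()
    print(M.grad)                              # non-zero (equal to W here), so M can be updated

That is the gradient behaviour I want to reproduce for the GPT-2 weights.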
I tried this approach, but self.pre_mask (the matrix M) doesn’t get updated during the backward pass:
    import torch
    import torch.nn as nn
    from transformers import GPT2LMHeadModel


    class ThresholdFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            # 0.5 is the threshold
            return (input > 0.5).float()

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through estimator: treat the threshold like an identity function
            return grad_output


    class StraightThroughEstimator(nn.Module):
        def __init__(self):
            super(StraightThroughEstimator, self).__init__()

        def forward(self, x):
            x = ThresholdFunction.apply(x)
            return x


    class Masked(nn.Module):
        def __init__(self, orig_layer):
            super().__init__()
            # the original layer holding the matrix W
            self.orig_layer = orig_layer
            # the matrix M
            self.pre_mask = torch.rand(self.orig_layer.weight.size(), requires_grad=True)
            self.masking_layer = StraightThroughEstimator()

        def forward(self, *x):
            # threshold M to get the binary mask, then apply it to W
            mask = nn.Parameter(self.masking_layer(self.pre_mask))
            masked_layer = self.orig_layer
            masked_layer.weight = nn.Parameter(self.orig_layer.weight * mask)
            output = masked_layer(*x)
            return output


    class GPT2_with_mask(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = GPT2LMHeadModel.from_pretrained('gpt2')
            # Example of modifying a GPT-2 layer so that it has a trainable mask
            self.model.transformer.h[0].attn.c_attn = Masked(self.model.transformer.h[0].attn.c_attn)
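For context, this is roughly how I’m checking whether M gets updated (simplified, with a dummy batch instead of my real data):

    model = GPT2_with_mask()
    masked_layer = model.model.transformer.h[0].attn.c_attn

    optimizer = torch.optim.Adam([masked_layer.pre_mask], lr=1e-3)

    input_ids = torch.randint(0, 50257, (1, 16))       # dummy token ids instead of real text
    outputs = model.model(input_ids=input_ids, labels=input_ids)

    optimizer.zero_grad()
    outputs.loss.backward()
    print(masked_layer.pre_mask.grad)                  # no gradient ever reaches pre_mask
    optimizer.step()                                   # so this step never changes it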
I’m also aware that using hooks might be another approach, but I’m not sure how to best proceed.
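In case it clarifies what I mean, this is roughly what I imagine a hook-based version could look like (untested sketch; deleting the entry from _parameters is just a workaround I came across so that the masked weight can stay a regular tensor in the autograd graph, and I’m not sure it’s the right way to do it):

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    layer = model.transformer.h[0].attn.c_attn

    orig_weight = layer.weight                         # keep a handle on the original W
    del layer._parameters['weight']                    # so a plain (non-Parameter) tensor can be assigned below
    pre_mask = torch.rand_like(orig_weight, requires_grad=True)

    def apply_mask(module, inputs):
        # Recompute the masked weight from the untouched original before every forward pass,
        # reusing the ThresholdFunction defined above
        mask = ThresholdFunction.apply(pre_mask)
        module.weight = orig_weight * mask

    handle = layer.register_forward_pre_hook(apply_mask)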
If it helps, this is the paper that I’m trying to implement: https://arxiv.org/pdf/2004.12406.pdf
Any inputs or suggestions would be really appreciated, thank you.