For the first, how about something like this:

```
class PartiallyFixedEmbedding(torch.nn.Module):
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        self.trainable_weight = torch.nn.Parameter(torch.empty(num_to_learn, fixed_weights.size(1)))
        torch.nn.init.kaiming_uniform_(self.trainable_weight)
        weight[self.num_fixed:] = self.trainable_weight
        self.register_buffer('weight', weight)

    def forward(self, inp):
        # drop the graph from the previous forward, then copy in the current
        # trainable rows so gradients flow back to trainable_weight
        self.weight.detach_()
        self.weight[self.num_fixed:] = self.trainable_weight
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
```

Now the `fixed_weights` bits in `weight` won’t be trained, but the `trainable_weight` will be.

You do have two copies of the trainable weights and copy them on forward, but the fixed weights will just sit there.

This can be elaborated to take more of the standard embedding layer’s parameters, but I left that out for now. It will run into trouble if you use the embedding twice in a single forward pass (because of the detach).
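To make the mechanics concrete, here is a small smoke test of this first approach (the class is repeated so the snippet runs standalone; the all-zero fixed weights and toy sizes are just for demonstration — in practice you’d pass in pretrained vectors):

```
import torch

class PartiallyFixedEmbedding(torch.nn.Module):
    # same module as above, repeated so this snippet is self-contained
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        self.trainable_weight = torch.nn.Parameter(torch.empty(num_to_learn, fixed_weights.size(1)))
        torch.nn.init.kaiming_uniform_(self.trainable_weight)
        weight[self.num_fixed:] = self.trainable_weight
        self.register_buffer('weight', weight)

    def forward(self, inp):
        self.weight.detach_()
        self.weight[self.num_fixed:] = self.trainable_weight
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)

# four fixed rows (zeros standing in for pretrained vectors) plus two learnable ones
emb = PartiallyFixedEmbedding(torch.zeros(4, 3), 2)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)  # only trainable_weight is a parameter

out = emb(torch.tensor([0, 4, 5]))  # mix of fixed and learnable indices
out.sum().backward()
opt.step()

fixed_after = emb(torch.tensor([0, 1, 2, 3]))  # fixed rows are untouched by the step
```

Note that `emb.parameters()` only yields `trainable_weight`, so the optimizer never sees the fixed rows.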

The second option I (all too tersely) tried to describe is something like

```
class PartiallyFixedEmbedding2(torch.nn.Module):
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        torch.nn.init.kaiming_uniform_(weight[self.num_fixed:])
        self.weight = torch.nn.Parameter(weight)
        # 1 for learnable rows, 0 for fixed ones; could be more flexible
        self.register_buffer('learnable_mask',
                             (torch.arange(self.num_fixed + num_to_learn).unsqueeze(1)
                              >= self.num_fixed).float())
        # register the hook once here; registering it in every forward
        # would keep piling up hooks on the parameter
        self.weight.register_hook(self.zero_grad_fixed)

    def zero_grad_fixed(self, gr):
        return gr * self.learnable_mask

    def forward(self, inp):
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
```

Here `zero_grad_fixed` zeros the gradient entries belonging to the fixed rows. It does so on the gradient of the full weight matrix, which, for large embeddings, isn’t efficient. In that case you can, however, use the same technique to compute the mask from `inp` and add the hook on the output of the embedding call. That way you’d multiply the gradient of the outputs of the embedding.
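Sketched out, that output-hook variant could look like the following (the class name and the exact mask construction are my own choices, not tested at scale):

```
import torch

class PartiallyFixedEmbedding3(torch.nn.Module):
    # variant that masks the gradient of the embedding *output*,
    # avoiding a mask over the full (possibly huge) weight matrix
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        torch.nn.init.kaiming_uniform_(weight[self.num_fixed:])
        self.weight = torch.nn.Parameter(weight)

    def forward(self, inp):
        out = torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
        if out.requires_grad:
            # 1 where inp looks up a learnable row, 0 for fixed rows;
            # unsqueeze so the mask broadcasts over the embedding dimension
            mask = (inp >= self.num_fixed).unsqueeze(-1).to(out.dtype)
            out.register_hook(lambda gr: gr * mask)
        return out

emb = PartiallyFixedEmbedding3(torch.zeros(4, 3), 2)
out = emb(torch.tensor([0, 4]))  # one fixed index, one learnable index
out.sum().backward()
```

The mask now has the shape of `inp` (plus a broadcast dimension) rather than the shape of the weight, so its cost scales with the batch, not the vocabulary.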

Best regards

Thomas