For the first, how about something like this:

```
class PartiallyFixedEmbedding(torch.nn.Module):
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        self.trainable_weight = torch.nn.Parameter(torch.empty(num_to_learn, fixed_weights.size(1)))
        torch.nn.init.kaiming_uniform_(self.trainable_weight)
        weight[self.num_fixed:] = self.trainable_weight
        self.register_buffer('weight', weight)

    def forward(self, inp):
        # drop the graph from the previous forward, then copy in the current
        # trainable rows so gradients flow back to trainable_weight
        self.weight.detach_()
        self.weight[self.num_fixed:] = self.trainable_weight
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
```

Now the `fixed_weights` bits in `weight` won’t be trained, but the `trainable_weight` will be.

You do have two copies of the trainable weights and copy them on forward, but the fixed weights will just sit there.

This can be elaborated to take more of the standard embedding layer’s parameters, but I left that out for now. It will run into trouble if you use the embedding twice in a single forward pass (because of the detach).
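To make the mechanics concrete, here is a small smoke test of this first approach (the class is repeated so the snippet runs standalone; the all-zero fixed weights and toy sizes are just for demonstration — in practice you’d pass in pretrained vectors):

```
import torch

class PartiallyFixedEmbedding(torch.nn.Module):
    # same module as above, repeated so this snippet is self-contained
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        self.trainable_weight = torch.nn.Parameter(torch.empty(num_to_learn, fixed_weights.size(1)))
        torch.nn.init.kaiming_uniform_(self.trainable_weight)
        weight[self.num_fixed:] = self.trainable_weight
        self.register_buffer('weight', weight)

    def forward(self, inp):
        self.weight.detach_()
        self.weight[self.num_fixed:] = self.trainable_weight
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)

# four fixed rows (zeros standing in for pretrained vectors) plus two learnable ones
emb = PartiallyFixedEmbedding(torch.zeros(4, 3), 2)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)  # only trainable_weight is a parameter

out = emb(torch.tensor([0, 4, 5]))  # mix of fixed and learnable indices
out.sum().backward()
opt.step()

fixed_after = emb(torch.tensor([0, 1, 2, 3]))  # fixed rows are untouched by the step
```

Note that `emb.parameters()` only yields `trainable_weight`, so the optimizer never sees the fixed rows.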

The second option I (all too tersely) tried to describe is something like

```
class PartiallyFixedEmbedding2(torch.nn.Module):
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        self.num_to_learn = num_to_learn
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        torch.nn.init.kaiming_uniform_(weight[self.num_fixed:])
        self.weight = torch.nn.Parameter(weight)
        # 1 for learnable rows, 0 for fixed ones; could be more flexible
        self.register_buffer('learnable_mask',
                             (torch.arange(self.num_fixed + num_to_learn).unsqueeze(1)
                              >= self.num_fixed).float())
        # register the hook once here; registering it in every forward
        # would keep piling up hooks on the parameter
        self.weight.register_hook(self.zero_grad_fixed)

    def zero_grad_fixed(self, gr):
        return gr * self.learnable_mask

    def forward(self, inp):
        return torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
```

Here `zero_grad_fixed` zeros the gradient entries belonging to the fixed rows. It does so on the gradient of the full weight matrix, which, for large embeddings, isn’t efficient. In that case you can, however, use the same technique to compute the mask from `inp` and add the hook on the output of the embedding call. That way you’d multiply the gradient of the outputs of the embedding.
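Sketched out, that output-hook variant could look like the following (the class name and the exact mask construction are my own choices, not tested at scale):

```
import torch

class PartiallyFixedEmbedding3(torch.nn.Module):
    # variant that masks the gradient of the embedding *output*,
    # avoiding a mask over the full (possibly huge) weight matrix
    def __init__(self, fixed_weights, num_to_learn):
        super().__init__()
        self.num_fixed = fixed_weights.size(0)
        weight = torch.empty(self.num_fixed + num_to_learn, fixed_weights.size(1))
        weight[:self.num_fixed] = fixed_weights
        torch.nn.init.kaiming_uniform_(weight[self.num_fixed:])
        self.weight = torch.nn.Parameter(weight)

    def forward(self, inp):
        out = torch.nn.functional.embedding(
            inp, self.weight, None, None, 2.0, False, False)
        if out.requires_grad:
            # 1 where inp looks up a learnable row, 0 for fixed rows;
            # unsqueeze so the mask broadcasts over the embedding dimension
            mask = (inp >= self.num_fixed).unsqueeze(-1).to(out.dtype)
            out.register_hook(lambda gr: gr * mask)
        return out

emb = PartiallyFixedEmbedding3(torch.zeros(4, 3), 2)
out = emb(torch.tensor([0, 4]))  # one fixed index, one learnable index
out.sum().backward()
```

The mask now has the shape of `inp` (plus a broadcast dimension) rather than the shape of the weight, so its cost scales with the batch, not the vocabulary.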

Best regards

Thomas