Use parametrizations to take the gradient of only part of a tensor?

I want to train only part of a parameter tensor, i.e. keep part of it fixed during training. This seems to be not an uncommon problem (e.g. here and here).

One way to achieve this is to set the gradient of the fixed part to zero before taking each optimizer step. Another solution is to keep two tensors, a trainable one and a fixed one, and concatenate them whenever the full tensor is needed. Both solutions require changing how the class is used compared to the standard case.
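
For concreteness, the two workarounds could look roughly like this (my own sketch with a made-up Linear layer and SplitFilter module, not code from the linked threads):

import torch
import torch.nn as nn

# Solution 1: zero out the gradient of the fixed slice before each step
layer = nn.Linear(5, 3)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(8, 5)

optimizer.zero_grad()
loss = layer(x).pow(2).mean()
loss.backward()
layer.weight.grad[:, :2] = 0.0  # keep the first two input columns fixed
optimizer.step()

# Solution 2: keep a fixed buffer and a trainable parameter, concatenate on use
class SplitFilter(nn.Module):
    def __init__(self, n_fixed, n_trainable):
        super().__init__()
        self.register_buffer("fixed", torch.ones(n_fixed))
        self.trainable = nn.Parameter(torch.randn(n_trainable))

    def forward(self, x):
        full_filter = torch.cat([self.fixed, self.trainable])
        return torch.einsum("ij,j->i", x, full_filter)

y = SplitFilter(2, 3)(torch.randn(8, 5))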

An alternative solution I came up with is to use parametrizations to construct the parameter tensor from a trainable and a non-trainable component (a more “elegant” version of solution 2 above). There is some work in defining the parametrization, but after that the class is used exactly as in the standard case (no fixed parameters).

I’m just looking for feedback on whether this is too convoluted or a good solution to a common problem. Below is an example where filter is a vector and the parametrization fixes its first N_DIM_FIXED elements to 1 (it also normalizes the trainable part, so that maximizing the variance in the toy loss stays bounded).

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Problem dimensions
N_DIM = 5
N_DIM_FIXED = 2

# Model class that projects input to a vector
class Projection(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.filter = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        return torch.einsum("ij,j->i", x, self.filter)

# Make data with high variance in last dimension
x = torch.randn(1000, N_DIM)
x[:, -1] *= 10

# Initialize model
model = Projection(N_DIM)

####### CODE ADDED TO FIX ELEMENTS OF filter #######

# Class that puts together trainable and fixed parts of a tensor
class PartialFixedTensor(nn.Module):
    def __init__(self, fixed_tensor):
        super().__init__()
        # Register the fixed part of the tensor
        self.register_buffer("fixed_tensor", fixed_tensor)

    def forward(self, X):
        # Concatenate the fixed part with the trainable part,
        # normalizing the trainable part so that maximizing the
        # variance below does not just blow up its magnitude
        return torch.cat([self.fixed_tensor, X / torch.norm(X)])

    def right_inverse(self, X):
        # Get dimensions of the fixed tensor
        n_fixed = self.fixed_tensor.shape[0]
        # Return the modifiable part of the tensor,
        # to initialize trainable part
        return X[n_fixed:]/torch.norm(X[n_fixed:])

# Add the parametrization to filter
fixed_tensor = torch.ones(N_DIM_FIXED)
parametrize.register_parametrization(
    model, "filter", PartialFixedTensor(fixed_tensor)
)

####### END OF CODE ADDED TO FIX ELEMENTS OF filter #######

# Print the initial filter
print(model.filter.detach())

# Optimize model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for i in range(1000):
    optimizer.zero_grad()
    y = model(x)
    loss = -torch.var(y)
    loss.backward()
    optimizer.step()

# The first N_DIM_FIXED elements of filter are unchanged (still 1)
print(model.filter.detach())
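
As a sanity check (just how parametrize stores things, not part of the question): the optimizer only ever sees the unconstrained 3-element tensor, which is kept under model.parametrizations.filter.original, so nothing else needs to change in the training loop.

# The only trainable tensor is the unconstrained part of the filter,
# stored by parametrize as model.parametrizations.filter.original
for name, p in model.named_parameters():
    print(name, p.shape)  # parametrizations.filter.original, torch.Size([3])

print(model.parametrizations.filter.original)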