How to train a model in a subspace of the parameter space

Hello,

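import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

class Subspace_model(nn.Module):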
    """
    Wraps a model in order to train it in a subspace
    """
    def __init__(self, model, E):
        super(Subspace_model, self).__init__()
        self.model = model
        self.register_buffer("E", E)
        self.d_dim = E.shape[1]
        self.params_d = nn.Parameter(torch.zeros(self.d_dim))
        self.register_buffer("params_0", parameters_to_vector(model.parameters()))

    def forward(self, x):
        params_D = self.E @ self.params_d + self.params_0
        vector_to_parameters(params_D, self.model.parameters())
        return self.model.forward(x)

I want to train a model in an affine-linear subspace of the full-parameter space. For that, I created an embedding E that takes vectors from R^d to R^D. Here, d is the dimensionality of the affine-linear subspace, D is the dimensionality of the full parameter space.
params_0 are the initial D-dimensional parameters of the model.
Initially, my d-dimensional parameter vector params_d is the zero vector. It is what I am trying to learn.

A forward pass works like this: First, I do E @ params_d + params_0 in order to get the new parameters of the model. Then I apply it to the input.
For the backward pass, I want to backpropagate through the model to get the gradients with respect to the D-dimensional weight vector, and then propagate them further through the matrix E in order to update the d-dimensional vector params_d. params_d is the only vector I want the optimizer step to update!

I think my code doesn’t work, and I’m not quite sure why. params_d doesn’t receive any gradient.

Potential reasons:

  • Maybe I need to register self.model somehow as a buffer as well, since I don’t want it to be updated. However, PyTorch doesn’t allow that.
  • Maybe the gradient is not able to pass through the “vector_to_parameters” step. If so, how could this be solved?

Best regards,
LL

Hi,

Could you show us the vector_to_parameters function please?

Hi albanD,

it is this one:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/convert_parameters.html

Here is the source code:

def vector_to_parameters(vec, parameters):
    r"""Convert one vector to the parameters

    Arguments:
        vec (Tensor): a single vector represents the parameters of a model.
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.
    """
    # Ensure vec of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameter is located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view_as(param).data

        # Increment the pointer
        pointer += num_param

As you can see, this function uses .data, which means that it breaks the computational graph. It is not expected to work with differentiable operations.
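
A small check illustrates the effect of the .data assignment. This is only a sketch with made-up tensors (w and p are hypothetical names, not part of the code above): the values are copied, but no gradient reaches the tensor they were computed from.

```python
import torch

# w plays the role of params_d, p the role of a model parameter (both hypothetical).
w = torch.zeros(3, requires_grad=True)
p = torch.nn.Parameter(torch.ones(3))

# Same pattern as in vector_to_parameters: only the values are copied,
# the autograd history of (2 * w) is dropped.
p.data = (2 * w).data

p.sum().backward()

w_grad_missing = w.grad is None      # nothing flowed back to w
p_grad_present = p.grad is not None  # p itself is still a leaf with a gradient
```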

nn.Parameter objects are meant to be trainable leaf tensors, not intermediate tensors that are computed from other tensors and need gradients to flow through them.
One way to work around this is to manually delete the old parameters and assign your new weights in their place:

del mod.weight # Remove it from the Parameter list
mod.weight = your_new_weights
# mod.weight is now just a Tensor, not a Parameter anymore.
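
Applied to the wrapper from the question, the delete-and-reassign idea could look roughly like the sketch below. The class name SubspaceModel, the slots bookkeeping list, and the toy network are assumptions made for illustration; it also assumes the model has no shared parameters, so that the flattening order matches parameters_to_vector.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


class SubspaceModel(nn.Module):
    """Hypothetical variant of the wrapper: inner weights become plain tensors."""

    def __init__(self, model, E):
        super().__init__()
        self.register_buffer("E", E)
        # Snapshot of the initial D-dimensional parameters, detached from the graph
        theta_0 = parameters_to_vector(model.parameters()).detach().clone()
        self.register_buffer("params_0", theta_0)
        self.params_d = nn.Parameter(torch.zeros(E.shape[1]))
        # Remember where each parameter lives, then delete it from its module
        self.slots = []
        for mod in model.modules():
            for name, p in list(mod.named_parameters(recurse=False)):
                self.slots.append((mod, name, p.shape, p.numel()))
                delattr(mod, name)
        self.model = model

    def forward(self, x):
        flat = self.E @ self.params_d + self.params_0
        pointer = 0
        for mod, name, shape, numel in self.slots:
            # Plain tensor attribute: it stays connected to params_d in the graph
            setattr(mod, name, flat[pointer:pointer + numel].view(shape))
            pointer += numel
        return self.model(x)


# Usage on a toy network (sizes are arbitrary)
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 2))
D = sum(p.numel() for p in net.parameters())
wrapped = SubspaceModel(net, torch.randn(D, 5))
wrapped(torch.randn(8, 4)).sum().backward()
trainable = [n for n, _ in wrapped.named_parameters()]
```

After the deletion, params_d is the only entry left in wrapped.parameters(), so a standard optimizer built from it updates nothing else.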

Hi Leon!

I don’t really follow your embedding scheme, and I can’t comment
on your code.

But if I understand your goal correctly, perhaps you can project your
gradients onto your desired subspace after an unmodified
backward() step.

The common optimizers use some variant of gradient descent
where params -= learning_rate * gradient (where gradient
might have some momentum history in it, but this doesn’t change
the idea).

So just do the same update step, but replace gradient with a
projected_gradient that lies in your subspace. (You would
have to tweak or write your own optimizer to do this.)
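
One possible way to sketch the projected-gradient step without writing a new optimizer is to overwrite the .grad fields between backward() and step(). The helper name project_grads_ is hypothetical, and the sketch assumes plain SGD without weight decay, that every parameter received a gradient, and that E has orthonormal columns (obtained here from a QR decomposition purely for illustration).

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


def project_grads_(model, E):
    """Overwrite each .grad with its slice of E @ (E.T @ g), g = flattened gradient."""
    params = list(model.parameters())
    g = torch.cat([p.grad.reshape(-1) for p in params])
    g_proj = E @ (E.T @ g)
    pointer = 0
    for p in params:
        n = p.numel()
        p.grad.copy_(g_proj[pointer:pointer + n].view_as(p))
        pointer += n


torch.manual_seed(0)
model = nn.Linear(4, 2)
D = sum(p.numel() for p in model.parameters())
E, _ = torch.linalg.qr(torch.randn(D, 3))  # D x 3 with orthonormal columns
theta_0 = parameters_to_vector(model.parameters()).detach().clone()

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(5):
    opt.zero_grad()
    loss = model(torch.randn(16, 4)).pow(2).mean()
    loss.backward()
    project_grads_(model, E)  # the only change to the usual training loop
    opt.step()

# Total displacement from the starting point; it should lie in the span of E
delta = parameters_to_vector(model.parameters()).detach() - theta_0
```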

Alternatively, you could perform the standard update (and not have
to modify the optimizer), and then project your new parameters to
lie in your subspace. (These two approaches are equivalent*, but I
think in terms of projecting the gradient rather than the parameters
for some reason.)

*) “Project gradient” and “project (updated) parameters” are equivalent
if your subspace is understood to contain the zero vector. From your
stated initial condition, I understand this to be the case.

Best.

K. Frank

Dear albanD and KFrank,

thanks for your answers!

@KFrank: The idea of changing the optimizer (SGD in my case) worked.

In order to do so, I split up the matrix E into several components, one for each array in the parameters of the model. Then, I could manually backpropagate the gradients of all the arrays through the embedding. I then re-embedded the result into the full parameter space.
Mathematically, that corresponds to multiplying the gradient by E @ E.T (a matrix product, which is an orthogonal projection when the columns of E are orthonormal), which is probably what you meant by “projection”.
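
As a short derivation of why the two views agree for plain SGD (the symbols θ for the D-dimensional parameters, z for params_d, θ_0 for params_0, and η for the learning rate are notation introduced here):

$$
\theta = \theta_0 + E z
\quad\Longrightarrow\quad
\nabla_z L = E^{\top} \nabla_\theta L
$$

$$
z \leftarrow z - \eta\, E^{\top} \nabla_\theta L
\quad\Longrightarrow\quad
\theta \leftarrow \theta - \eta\, E E^{\top} \nabla_\theta L
$$

So a gradient step on z in the d-dimensional space is the same as a step on θ with the gradient multiplied by E E^T.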

In case you’re interested: I’m reproducing the results from this paper on the intrinsic dimension of objective landscapes: https://arxiv.org/abs/1804.08838

Best regards,
Leon