How do I get the gradient of the loss w.r.t. the step length for a step-length (line) search?

I am trying to implement a custom optimizer that performs a step-length (line) search.
It requires computing $\phi(a) = \mathrm{net}(w + a\,p)$ and its derivative $d\phi/da$,
where

  • $a$ is the step-length,
  • $w$ is the weight vector,
  • $p$ is the step-direction vector,
  • net is the neural network function that outputs a scalar loss value $L$.

I can compute $d\phi/da$, which equals $dL/da$, for the simple net below, written with raw tensors.

The question is:
for a net constructed using nn.Module,

  • how do I compute the gradient of the loss w.r.t. the step length?
  • how do I update the net parameters so that the gradient of the loss w.r.t. the step length can be computed?
#!/usr/bin/env python3
import torch
torch.manual_seed(12345)

N, D_in, H, D_out = 10, 2, 5, 1

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Base weights w of a two-layer MLP written with raw tensors.
# (They were previously created twice; the second draw silently replaced
# the first and consumed extra RNG state.)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

# Step direction p. It is a fixed input to the line search, so it does not
# need to track gradients.
p1 = torch.randn(D_in, H)
p2 = torch.randn(H, D_out)

# Step length a: the quantity phi(a) is differentiated with respect to.
a = torch.randn(1, requires_grad=True)

# Trial point w + a * p. Bound to new names so the base weights w1/w2 stay
# intact and can be reused for every trial value of `a`.
w1_a = w1 + a * p1
w2_a = w2 + a * p2

h = x.mm(w1_a)
h_relu = h.clamp(min=0)

y_pred = h_relu.mm(w2_a)
loss = (y_pred - y).pow(2).mean()

# ga = phi'(a) = dL/da. create_graph=True also records the graph of ga
# itself; that is only needed if a second derivative along p is wanted.
ga, = torch.autograd.grad(loss, a, create_graph=True)
print(ga)

# Net using nn.module #####
class Net(torch.nn.Module):
    """Two-layer MLP: Linear(D_in, H) -> ReLU -> Linear(H, D_out)."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # These attribute names determine the names reported by
        # named_parameters(): 'hidden.weight', 'hidden.bias',
        # 'output.weight', 'output.bias'.
        self.hidden = torch.nn.Linear(D_in, H)
        self.output = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Apply hidden layer, ReLU, then output layer to ``x`` (last dim D_in)."""
        activations = torch.nn.functional.relu(self.hidden(x))
        return self.output(activations)

net = Net(2, 5, 1)
loss_fn = torch.nn.MSELoss()

# nn.Linear stores its weight as (out_features, in_features), the transpose
# of the (in, out) layout used by the raw-tensor net above.
p1 = torch.transpose(p1, 0, 1)
p2 = torch.transpose(p2, 0, 1)

# functional_call runs a module with substitute parameter tensors without
# modifying the module's own Parameters. It lives in torch.func since
# PyTorch 2.0 and in torch.nn.utils.stateless since 1.12.
try:
    from torch.func import functional_call
except ImportError:
    from torch.nn.utils.stateless import functional_call

# Step direction per parameter name. Parameters not listed here (the biases)
# stay fixed along the search line.
step_dirs = {'hidden.weight': p1, 'output.weight': p2}

# Why the two earlier attempts failed:
#   * p.add_(a * p1) mutates a leaf Parameter that requires grad in place,
#     which autograd forbids.
#   * p = p + a * p1 only rebinds the loop variable; the module keeps using
#     its original Parameter, so `a` never enters the graph of `loss`.
# Instead, build w + a * p out of place and feed it to the forward pass.
params_at_a = {
    name: param + a * step_dirs[name] if name in step_dirs else param
    for name, param in net.named_parameters()
}

# phi(a) = loss(net evaluated at w + a * p)
y_pred = functional_call(net, params_at_a, (x,))
loss = loss_fn(y_pred, y)
print(loss.item())

# phi'(a) = dL/da
ga, = torch.autograd.grad(loss, a)
print(ga)

Is your first approach working?
You can find a similar example here. Could you try to use the parameter updates shown in the code?

@ptrblck: thanks for the link.

Yes, I was following that tutorial and ended up with the following, which so far works as expected.
I would be grateful if you could comment on it, thank you :slightly_smiling_face:

class BareNet(torch.nn.Module):
    """Two-layer MLP evaluated at the trial point ``w + alpha * p``.

    Every weight/bias has a companion step-direction parameter (``*_pw`` /
    ``*_pb``), and all share one step length ``alpha``. ``forward`` uses
    ``w + alpha * p``, so ``d loss / d alpha`` is available from autograd.
    Step directions and ``alpha`` start at zero, so a freshly built net is a
    plain MLP.

    Layouts: weights ``*_w`` are (out_features, in_features), as in
    nn.Linear. Step directions ``*_pw`` are (in_features, out_features), the
    transposed layout that ``forward`` adds to ``w.T`` directly.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        def _init_linear_layer(w, b):
            # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) init of weight w
            # (out, in) and bias b. Following:
            # https://github.com/pytorch/pytorch/blob/769cb5a6405b39a0678e6bc4f2d6fea62e0d3f12/torch/nn/modules/linear.py#L48
            import math

            stdv = 1. / math.sqrt(w.size(1))
            # In-place init of leaf Parameters must bypass autograd.
            with torch.no_grad():
                w.uniform_(-stdv, stdv)
                b.uniform_(-stdv, stdv)

        self.hidden_w = torch.nn.Parameter(torch.zeros(hidden_dim, input_dim))
        self.hidden_b = torch.nn.Parameter(torch.zeros(hidden_dim))
        self.hidden_pw = torch.nn.Parameter(torch.zeros(input_dim, hidden_dim))
        self.hidden_pb = torch.nn.Parameter(torch.zeros(hidden_dim))
        # The output layer maps hidden_dim -> output_dim. These were
        # previously sized with input_dim, so the net produced input_dim
        # outputs instead of output_dim.
        self.output_w = torch.nn.Parameter(torch.zeros(output_dim, hidden_dim))
        self.output_b = torch.nn.Parameter(torch.zeros(output_dim))
        self.output_pw = torch.nn.Parameter(torch.zeros(hidden_dim, output_dim))
        self.output_pb = torch.nn.Parameter(torch.zeros(output_dim))
        # Step length shared by all layers; 0 evaluates the base weights.
        self.alpha = torch.nn.Parameter(torch.zeros(1))

        # Parameter groups, one list per role (weights, biases, weight step
        # directions, bias step directions, step length).
        self.w_params = torch.nn.ParameterList([self.hidden_w, self.output_w])
        self.b_params = torch.nn.ParameterList([self.hidden_b, self.output_b])
        self.pw_params = torch.nn.ParameterList([self.hidden_pw, self.output_pw])
        self.pb_params = torch.nn.ParameterList([self.hidden_pb, self.output_pb])
        self.a_params = torch.nn.ParameterList([self.alpha])
        # Plain list (not registered again): weights followed by biases.
        self.wb_params = list(self.w_params) + list(self.b_params)

        _init_linear_layer(self.hidden_w, self.hidden_b)
        _init_linear_layer(self.output_w, self.output_b)

    def forward(self, x):
        """Run the MLP at ``w + alpha * p`` on ``x`` of shape (N, input_dim).

        Returns a tensor of shape (N, output_dim).
        """
        hidden_w = self.hidden_w.transpose(0, 1) + (self.alpha * self.hidden_pw)
        hidden_b = self.hidden_b + (self.alpha * self.hidden_pb)
        output_w = self.output_w.transpose(0, 1) + (self.alpha * self.output_pw)
        output_b = self.output_b + (self.alpha * self.output_pb)
        y = torch.nn.functional.relu(x.mm(hidden_w) + hidden_b)
        y = y.mm(output_w) + output_b
        return y

And in my_optim.py:

...

        def _phi(alpha):
            """Evaluate phi(alpha) = loss(w + alpha * p) and phi'(alpha).

            Loads ``alpha`` and the step directions into the a_/pw_/pb_
            parameters, calls ``closure(do_backward=False)``, and
            differentiates the returned loss w.r.t. the a_ parameters.
            Afterwards the a_/pw_/pb_ parameters are zeroed again, also when
            the closure or the gradient computation raises, so the model is
            never left at the trial point.

            Returns:
                (loss, dloss/dalpha) as Python floats.
            """
            # Update a_, pw_, pb_ params. In-place writes to leaf Parameters
            # must happen outside autograd.
            with torch.no_grad():
                for p in a_params:
                    p.fill_(alpha)
                for i, p in enumerate(pw_params):
                    # presumably w_step_dirs[i] has the (out, in) layout of the
                    # weights, while pw params are (in, out) -- TODO confirm
                    p.copy_(w_step_dirs[i].transpose(0, 1))
                for i, p in enumerate(pb_params):
                    p.copy_(b_step_dirs[i])

            try:
                # Get loss and its grad w.r.t. the step length
                loss = closure(do_backward=False)
                grad_alpha, = torch.autograd.grad(loss, a_params, create_graph=False)
            finally:
                # Zero a_, pw_, pb_ params so forward() is back at alpha == 0
                with torch.no_grad():
                    for p in a_params:
                        p.zero_()
                    for p in pw_params:
                        p.zero_()
                    for p in pb_params:
                        p.zero_()

            return (loss.item(), grad_alpha.item())
...