Higher order: gradient of the optimization procedure of a whole nn.Module

Hi, I’m trying to differentiate through an optimization procedure of another neural net. The code below shows 2 functions:

  1. a working version where, instead of neural nets, we just have single tensors
  2. the version with neural nets, with three attempted solutions commented out.

Both should be equivalent.
I’m a newbie, so I’d welcome explanations of what I’m actually doing in those failed attempts.

For context, I’m implementing the LOLA paper where agents in games model each other’s learning: https://arxiv.org/pdf/1709.04326.pdf

import torch

gd_iter = 5
lr = 0.1
A = torch.tensor([[-2., -3.], [0., -1.]])


class SelfOutputNet(torch.nn.Module):
   """Net that just returns it's own parameter vector"""
    def __init__(self):
        super(SelfOutputNet, self).__init__()
        self.theta = torch.nn.Parameter(torch.zeros(2, requires_grad=True))
    def forward(self):
        return self.theta


def test1_tensor_working():
   """Uses tensors"""
    theta1 = torch.zeros(2, requires_grad=True)
    theta2 = torch.zeros(2, requires_grad=True)
    theta1_optimizer = torch.optim.SGD((theta1,), lr=lr)

    def fake_objective_func(x1, x2):
        """Prisoner's dilemma loss"""
        x1, x2 = torch.softmax(x1, 0), torch.softmax(x2, 0)
        return x1.view(1, -1).mm(A).view(-1).dot(x2)

    # Inner optimization
    for k in range(gd_iter):
        fake_objective2 = fake_objective_func(theta2, theta1)
        grad2 = torch.autograd.grad(fake_objective2, theta2, create_graph=True)[0]
        theta2 = theta2 - lr * grad2  # Produces the result it should with my true objective func

    # Outer optimization
    fake_objective1 = fake_objective_func(theta1, theta2)
    theta1_optimizer.zero_grad()
    fake_objective1.backward()
    theta1_optimizer.step()


def test2_NN_not_working():
   """Equivalent test using nn.Modules"""
    net1 = SelfOutputNet()
    net2 = SelfOutputNet()
    theta1_optimizer = torch.optim.SGD(net1.parameters(), lr=lr)

    def fake_objective_func(x1, x2):
        """Prisoner's dilemma loss"""
        out1, out2 = x1.forward(), x2.forward()
        x1, x2 = torch.softmax(out1, 0), torch.softmax(out2, 0)
        return - x1.view(1, -1).mm(A).view(-1).dot(x2)

    # Inner optimization
    for k in range(gd_iter):
        fake_objective2 = fake_objective_func(net2, net1)
        grad2 = torch.autograd.grad(fake_objective2, net2.parameters(), create_graph=True)
        for i, layer in enumerate(net2.parameters()):
            """Three failing solutions:"""
            # layer.sub_(lr * grad2[i])       # RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
            # layer.data.sub_(lr * grad2[i])  # Has no effect as measured by final outcome when using my true objective func
            # layer = lr * grad2[i]           # Has no effect

    # Outer optimization
    fake_objective1 = fake_objective_func(net1, net2)
    theta1_optimizer.zero_grad()
    fake_objective1.backward()
    theta1_optimizer.step()

test1_tensor_working()
test2_NN_not_working()

Add with torch.no_grad() right before the weight update:

with torch.no_grad():
    for i, layer in enumerate(net2.parameters()):
        layer.sub_(lr * grad2[i])

Thanks ptrblck! This does remove the RuntimeError. But the two functions still aren’t leading to the same result, which they should. I’ve verified in a few ways that introducing the 2-parameter nn.Module instead of an equivalent 2-parameter tensor causes the difference; only the result for the tensor version is correct. I can also see that the solution you offered does at least change something compared to not doing the inner optimization at all. Maybe no_grad() removes the ability to differentiate through the opponent’s learning? Any other solutions?

When I remove the minus sign in the return statement of test2_NN_not_working.fake_objective_func, theta1 and theta2 have the same values in both runs.
I assume it’s a typo, or do you want to negate the objective in the second case?

Thanks for the quick response; yes, that’s a typo. The outcome is only the same here because it uses a fake objective function (prisoner’s dilemma). For the fake one, differentiating through the opponent shouldn’t do anything anyway. The true objective is an iterated prisoner’s dilemma. I can share the whole code if that helps.

As far as I understand, your examples now both work in the same way, but apparently they are not doing what you want them to?
Sure, share the code and explain a bit more about your use case.

Sure, give me a moment to make it shareable. My use case is this: player 1 wants to take a gradient step; it models that player 2 will take one in response (the inner optimization); player 1 wants to take its step such that, after player 2’s response, player 1’s loss will be low. This outer objective is given in eq. 4.2 (left side) of the linked paper. I suspect no_grad() doesn’t work because it stops the influence of net1 on fake_objective2, which in turn affects the update net2 makes.

Regarding your last question: The two examples aren’t working the same way but they should be. They only gave you the same outcome (modulo the minus) because I uploaded code with a fake objective function that causes this.
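
To illustrate what I mean, here’s a minimal sketch with a toy quadratic (not the LOLA objective): when the inner update is done out of place, the opponent’s updated parameter is still a function of my parameter, so the outer gradient picks up an extra term; an in-place update under no_grad() drops exactly that term.

import torch

lr = 0.1
theta1 = torch.tensor([1.0], requires_grad=True)
theta2 = torch.tensor([2.0], requires_grad=True)

# Inner step: player 2 minimizes (theta2 - theta1)^2, so its update depends on theta1
inner_loss = (theta2 - theta1).pow(2).sum()
grad2 = torch.autograd.grad(inner_loss, theta2, create_graph=True)[0]
theta2_new = theta2 - lr * grad2   # out of place: theta2_new still depends on theta1

# Outer loss for player 1, evaluated at the opponent's updated parameter
outer_loss = (theta1 * theta2_new).sum()
outer_loss.backward()
print(theta1.grad)  # tensor([2.]): the direct term 1.8 plus 0.2 coming through the opponent's update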

Can you execute this code? I’ve stuffed it all in one script. It plots the result.

The hyperparameter hp.tensors toggles between test1 and test2, i.e. tensors and Modules. You shouldn’t need to understand the various things I defined after the tests.

Again the two tests are equivalent except one uses a Module instead of a tensor. They implement the use case I described.

import torch
import matplotlib.pyplot as plt
from copy import deepcopy
from torch.distributions import Bernoulli


class HyParams():
    def __init__(self):
        'Choosing tensors vs nn.Modules'
        self.tensors = False

        self.lr_out = 0.2    # default: 0.2
        self.lr_in = 0.3     # default: 0.3
        # self.optim_algo = torch.optim.Adam
        self.optim_algo = torch.optim.SGD
        self.gamma = 0.96
        self.n_outer_opt = 70
        self.n_inner_opt = 2
        self.len_rollout = 10
        self.batch_size = 64
        self.seed = 42

hp = HyParams()

class SelfOutputNet(torch.nn.Module):
    def __init__(self):
        super(SelfOutputNet, self).__init__()
        self.theta = torch.nn.Parameter(torch.zeros(hp.num_states, requires_grad=True))
        self.optimizer = torch.optim.SGD(self.parameters(), lr=hp.lr_out)
    def forward(self):
        return self.theta


def test1_tensors_working(n_inner_opt):
    print("start iterations with", n_inner_opt, "lookaheads:")
    joint_scores = []

    T1 = torch.zeros(hp.num_states, requires_grad=True)
    T2 = torch.zeros(hp.num_states, requires_grad=True)
    T1.optimizer = torch.optim.SGD((T1,), lr=hp.lr_out)
    T2.optimizer = torch.optim.SGD((T2,), lr=hp.lr_out)

    for update in range(hp.n_outer_opt):
        def LOLA_step(T1, T2_):

            # Inner optimization
            for k in range(n_inner_opt):
                true_objective2 = game_tensor.true_objective(T2_, T1)
                grad2 = torch.autograd.grad(true_objective2, (T2_,), create_graph=True)[0]
                T2_ = T2_ - hp.lr_in * grad2

            # Outer optimization
            true_objective1 = game_tensor.true_objective(T1, T2_)
            T1.optimizer.zero_grad()
            true_objective1.backward()
            T1.optimizer.step()

        T2_ = deepcopy(T2)
        T1_ = deepcopy(T1)
        LOLA_step(T1, T2_)
        LOLA_step(T2, T1_)

        joint_scores = eval_and_print(joint_scores, update, T1, T2)

    return joint_scores


def test2_NN_not_working(n_inner_opt):
    print("start iterations with", n_inner_opt, "lookaheads:")
    joint_scores = []

    net1 = SelfOutputNet()
    net2 = SelfOutputNet()

    for update in range(hp.n_outer_opt):
        def LOLA_step(net1, net2_):

            # Inner optimization
            for k in range(n_inner_opt):
                true_objective2 = game_NN.true_objective(net2_, net1)
                grad2 = torch.autograd.grad(true_objective2, net2_.parameters(), create_graph=True)
                with torch.no_grad():
                    for i, layer in enumerate(net2_.parameters()):
                        layer.sub_(hp.lr_in * grad2[i])

            # Outer optimization
            true_objective1 = game_NN.true_objective(net1, net2_)
            net1.optimizer.zero_grad()
            true_objective1.backward()
            net1.optimizer.step()

        net2_ = deepcopy(net2)
        net1_ = deepcopy(net1)
        LOLA_step(net1, net2_)
        LOLA_step(net2, net1_)

        joint_scores = eval_and_print(joint_scores, update, net1, net2)

    return joint_scores








"""
========================================VARIOUS HELPER FUNCTIONS AND CLASSES ===========================================
"""
def eval_func_NN(input):
    return input.forward()
def eval_func_tensor(input):
    return input





def act(batch_states, theta):

    batch_states = torch.from_numpy(batch_states).long()
    probs = torch.sigmoid(eval_func(theta))[batch_states]
    m = Bernoulli(1-probs)
    actions = m.sample()
    log_probs_actions = m.log_prob(actions)
    return actions.numpy().astype(int), log_probs_actions


def eval_and_print(joint_scores, update, agent1, agent2):
    # evaluate:
    score = step(agent1, agent2)
    joint_scores.append(0.5 * (score[0] + score[1]))

    # print
    if update % 10 == 0:
        p1 = [p.item() for p in torch.sigmoid(eval_func(agent1))]
        p2 = [p.item() for p in torch.sigmoid(eval_func(agent2))]
        print('update', update, 'score (%.3f,%.3f)' % (score[0], score[1]), 'policy 1:', p1, 'policy 2:', p2)
    return joint_scores


def step(theta1, theta2):
    # just to evaluate progress:
    (s1, s2), _ = game.reset()
    score1 = 0
    score2 = 0
    for t in range(hp.len_rollout):
        a1, lp1 = act(s1, theta1)
        a2, lp2 = act(s2, theta2)
        (s1, s2), (r1, r2),_,_ = game.step((a1, a2))
        # cumulate scores
        score1 += np.mean(r1)/float(hp.len_rollout)
        score2 += np.mean(r2)/float(hp.len_rollout)
    return (score1, score2)


'import OneHot'
import gym
import numpy as np

from gym.spaces import prng


class OneHot(gym.Space):
    """
    One-hot space. Used as the observation space.
    """
    def __init__(self, n):
        self.n = n

    def sample(self):
        return prng.np_random.multinomial(1, [1. / self.n] * self.n)



"""
Iterated Prisoner's dilemma environment.
"""
from gym.spaces import Discrete, Tuple

class IteratedPrisonersDilemma(gym.Env):
    """
    A two-agent vectorized environment.
    Possible actions for each agent are (C)ooperate and (D)efect.
    """
    # Possible actions
    NUM_AGENTS = 2
    NUM_ACTIONS = 2
    NUM_STATES = 5

    def __init__(self, max_steps, gamma, eval_func, batch_size=1):
        self.max_steps = max_steps
        self.batch_size = batch_size
        self.payout_mat = np.array([[-2,0],[-3,-1]])
        self.states = np.array([[1,2],[3,4]])
        self.gamma = gamma
        self.input = input

        self.action_space = Tuple([
            Discrete(self.NUM_ACTIONS) for _ in range(self.NUM_AGENTS)
        ])
        self.observation_space = Tuple([
            OneHot(self.NUM_STATES) for _ in range(self.NUM_AGENTS)
        ])
        self.available_actions = [
            np.ones((batch_size, self.NUM_ACTIONS), dtype=int)
            for _ in range(self.NUM_AGENTS)
        ]

        self.step_count = None
        self.eval_func = eval_func


    def reset(self):
        self.step_count = 0
        init_state = np.zeros(self.batch_size)
        observation = [init_state, init_state]
        info = [{'available_actions': aa} for aa in self.available_actions]
        return observation, info

    def true_objective(self, theta1, theta2):
        """Differentiable objective in torch"""
        p1 = torch.sigmoid(self.eval_func(theta1))
        # p2 = torch.sigmoid(eval_func(theta2[[0,1,3,2,4]]))
        p2 = torch.sigmoid(self.eval_func(theta2))
        p0 = (p1[0], p2[0])
        p = (p1[1:], p2[1:])
        # create initial laws, transition matrix and rewards:
        def phi(x1, x2):
            return [x1 * x2, x1 * (1 - x2), (1 - x1) * x2, (1 - x1) * (1 - x2)]
        P0 = torch.stack(phi(*p0), dim=0).view(1,-1)
        P = torch.stack(phi(*p), dim=1)
        R = torch.from_numpy(self.payout_mat).view(-1,1).float()
        # the true value to optimize:
        objective = (P0.mm(torch.inverse(torch.eye(4) - self.gamma*P))).mm(R)
        return -objective

    def step(self, action):
        ac0, ac1 = action
        self.step_count += 1

        r0 = self.payout_mat[ac0, ac1]
        r1 = self.payout_mat[ac1, ac0]
        s0 = self.states[ac0, ac1]
        s1 = self.states[ac1, ac0]
        observation = [s0, s1]
        reward = [r0, r1]
        done = (self.step_count == self.max_steps)
        info = [{'available_actions': aa} for aa in self.available_actions]
        return observation, reward, done, info

    def contains(self, x):
        return isinstance(x, np.ndarray) and \
               x.shape == (self.n, ) and \
               np.all(np.logical_or(x == 0, x == 1)) and \
               np.sum(x) == 1

    @property
    def shape(self):
        return (self.n, )

    def __repr__(self):
        return "OneHot(%d)" % self.n

    def __eq__(self, other):
        return self.n == other.n


'Create game env'
if hp.tensors:
    game_tensor, hp.num_states = IteratedPrisonersDilemma(hp.len_rollout, hp.gamma, eval_func_tensor, hp.batch_size), 5
    game = game_tensor
    eval_func = eval_func_tensor
else:
    game_NN, hp.num_states = IteratedPrisonersDilemma(hp.len_rollout, hp.gamma, eval_func_NN, hp.batch_size), 5
    game = game_NN
    eval_func = eval_func_NN



# plot progress:
if __name__=="__main__":

    colors = ['b','c','m','r']

    for i in range(1, hp.n_inner_opt):
        torch.manual_seed(hp.seed)
        scores = np.array(test1_tensors_working(i)) if hp.tensors else np.array(test2_NN_not_working(i))
        plt.plot(scores, colors[i], label=str(i)+" lookaheads")

    plt.legend()
    plt.xlabel('grad steps')
    plt.ylabel('joint score')
    plt.show(block=True)

I found a solution: updating the Module’s _parameters dictionary directly:

# Inner optimization
for k in range(n_inner_opt):
    true_objective2 = game_NN.true_objective(net2_, net1)
    grad2 = torch.autograd.grad(true_objective2, net2_.parameters(), create_graph=True)
    'In the following lines, I needed to update the ._parameters dictionary while not using torch.no_grad()'
    for i, (layer_name, layer) in enumerate(net2_._parameters.items()):
        layer = layer - hp.lr_in * grad2[i]
        net2_._parameters[layer_name] = layer

Turns out my solution was not general enough. It only works for parameters that are created directly in the init function as nn.Parameter, but not for parameters that belong to a submodule such as nn.Linear; those aren’t in net._parameters. I can’t think of a general solution that isn’t very cumbersome. Any more tips?

These parameters should also be there.
However, it’s usually not recommended to use internal attributes (those starting with an underscore).
Could you try to get your parameters using model.named_parameters()?

import torch.nn as nn

model = nn.Linear(10, 10)
print(dict(model.named_parameters())['weight'])

That one does contain all parameters whereas model._parameters is empty:

>>> dict(net1.named_parameters())

{'fc2.bias': Parameter containing:
tensor([-0.1455,  0.3597], requires_grad=True), 'fc1.weight': Parameter containing:
tensor([[ 0.5153],
        [-0.4414]], requires_grad=True), 'fc3.weight': Parameter containing:
tensor([[ 0.0983, -0.0866],
        [ 0.1961,  0.0349],
        [ 0.2583, -0.2756],
        [-0.0516, -0.0637],
        [ 0.1025, -0.0028]], requires_grad=True), 'fc2.weight': Parameter containing:
tensor([[-0.1371,  0.3319],
        [-0.6657,  0.4241]], requires_grad=True), 'fc3.bias': Parameter containing:
tensor([ 0.6181,  0.2200, -0.2633, -0.4271, -0.1185], requires_grad=True)}

>>> dict(net1._parameters)

{}

Is there maybe a way to change a model’s parameters for my case without directly modifying model._parameters?

I believe the problem @SoerenMind was experiencing is that ._parameters does not contain the parameters defined inside submodules (e.g. nn.Linear layers) when the model is defined as a class. Is there a solution to that?
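
A minimal sketch of one way to generalise the _parameters trick to nested modules (the helper name differentiable_sgd_step is made up here; it assumes each dotted name from named_parameters() points at the submodule that owns the parameter, which holds for standard setups like nn.Sequential or layers created in __init__): walk each name down to the owning submodule and overwrite that submodule’s _parameters entry with the out-of-place update, so the new tensor stays in the autograd graph.

import torch
import torch.nn as nn

def differentiable_sgd_step(net, grads, lr):
    # Sketch only: same idea as overwriting net._parameters above, but applied to the
    # submodule that owns each parameter, so it also covers layers such as nn.Linear.
    for (name, param), grad in zip(list(net.named_parameters()), grads):
        *path, leaf = name.split('.')   # e.g. 'fc1.weight' -> ['fc1'], 'weight'
        owner = net
        for attr in path:               # walk down to the owning submodule
            owner = getattr(owner, attr)
        owner._parameters[leaf] = param - lr * grad   # non-leaf tensor, stays differentiable

# usage sketch
net = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
loss = net(torch.randn(8, 3)).pow(2).mean()
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
differentiable_sgd_step(net, grads, lr=0.1)

This keeps the inner update differentiable, but it still relies on Module internals, so treat it as a sketch rather than an official API (it is essentially what libraries like higher automate).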