Understanding load_state_dict() effect on computational graph

Hi all,

Let’s say I have a network model and two weight files, W1 and W2.

The network is initialized with weights W1, a forward pass is performed, and the loss is computed.

Now the network weights are switched over to W2 and backward() is called.

Will this compute gradients w.r.t. W2 or W1?

Most likely neither, since gradient computation often uses both the inputs and the saved intermediate & final outputs. In this case they would mismatch.

Like @SimonW said, if no mismatch occurs, the saved intermediate values will be used, so the gradients computed w.r.t. W1 will be applied to W2.

I created a small example. Is this what you would like to do?

import torch
from torch import nn, optim
from torch.autograd import Variable

x = Variable(torch.ones(1, 4))
w = Variable(torch.ones(4, 1), requires_grad=True)
y = Variable(torch.ones(1, 1))

output = torch.matmul(x, w)

criterion = nn.MSELoss()
optimizer = optim.SGD([w], lr=0.1)

loss = criterion(output, y)
# w.data.fill_(2)  # uncomment to switch the weights to W2 after the forward pass
loss.backward()

print('w {}grad {}'.format(w.data, w.grad.data))

optimizer.step()
print('w {}grad {}'.format(w.data, w.grad.data))
optimizer.zero_grad()

Hi @ptrblck,

I really appreciate your effort of writing up a code to explain what is happening. And yeah, your code captures what I was looking for.

I extended your code from one layer to two (since for a single linear layer the weight gradient xᵀ · ∂L/∂y does not involve the weights themselves, changing them has no effect there) to actually see the effect of changing weights after the loss computation.

Based on this observation, it looks like the gradients are computed during the backward() call using the new weights, i.e. W2.

import torch
from torch.autograd import Variable
from torch import nn
from torch import optim

from graphviz import Digraph
# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot


w0 = Variable(torch.ones(2, 2), requires_grad=True) 
w = Variable(torch.ones(2, 1), requires_grad=True)

x = Variable(torch.ones(1, 2))
y = Variable(torch.ones(1, 1))

criterion = nn.MSELoss()
optimizer = optim.SGD([w0,w], lr=0.01)

for i in range(1):

    output = x.mm(w0).mm(w)

    # if i == 0:
    #     make_dot(output).view()

    optimizer.zero_grad()
    loss = criterion(output, y)
    # w0.data.fill_(2)  # uncomment to switch w0 to W2 after the forward pass
    # w.data.fill_(2)   # uncomment to switch w to W2 after the forward pass
    loss.backward()

    print('w0 {}grad {}'.format(w0.data, w0.grad.data))
    print('w {}grad {}'.format(w.data, w.grad.data))


    optimizer.step()
    print('w0 {}grad {}'.format(w0.data, w0.grad.data))
    print('w {}grad {}'.format(w.data, w.grad.data))
    optimizer.zero_grad()

Thanks for updating the code. It was indeed a slip to use just one layer.
It seems it’s using W2. However, I would use this approach carefully, since I’m not sure how the intermediate values are stored when other layers are involved, as the sketch below illustrates.
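
For instance, a minimal sketch with a ReLU between the two layers: the ReLU and the loss keep activations from the old forward pass, while the mm backward picks up the new value of w, so the resulting gradient corresponds to neither W1 nor W2 exactly.

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.ones(1, 2))
w0 = Variable(torch.ones(2, 2), requires_grad=True)
w = Variable(torch.ones(2, 1), requires_grad=True)
y = Variable(torch.ones(1, 1))

h = F.relu(x.mm(w0))    # ReLU saves its output (computed with the old w0)
output = h.mm(w)        # mm saves h and w for the backward pass

loss = F.mse_loss(output, y)
w.data.fill_(2)         # switch w to "W2" after the forward pass
loss.backward()

# w0.grad is built from the old activations / ReLU mask but the new value of w.
print(w0.grad, w.grad)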

Could you tell a bit more about your use case?

I am trying to replicate the algorithm outlined in the MAML paper:

Pseudo code: [algorithm listing from the MAML paper]

The idea is to compute a second-order gradient. While computing the first-order gradient is straightforward, computing the second-order gradient as described in the paper is a little tricky. The pseudo code should give a good understanding; a rough sketch follows below.
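
Roughly, this is how I read that pseudo code in PyTorch terms, as a sketch with a single functional linear model and toy tasks (all the names and hyper-parameters here are just placeholders):

import torch
import torch.nn.functional as F

# A single linear model whose weights are kept as plain tensors.
w = torch.randn(1, 5, requires_grad=True)   # shape (out_features, in_features) for F.linear
b = torch.zeros(1, requires_grad=True)
inner_lr, meta_lr = 0.01, 0.001
meta_opt = torch.optim.SGD([w, b], lr=meta_lr)

def sample_task():
    # Placeholder task: a random linear-regression problem.
    x = torch.randn(10, 5)
    true_w = torch.randn(1, 5)
    y = x.mm(true_w.t())
    return x[:5], y[:5], x[5:], y[5:]   # support set, query set

for step in range(100):
    meta_opt.zero_grad()
    meta_loss = 0.0
    for _ in range(4):   # a batch of tasks
        xs, ys, xq, yq = sample_task()
        # Inner step on the support set, keeping the graph so the update is differentiable.
        support_loss = F.mse_loss(F.linear(xs, w, b), ys)
        gw, gb = torch.autograd.grad(support_loss, (w, b), create_graph=True)
        w_adapted, b_adapted = w - inner_lr * gw, b - inner_lr * gb
        # Outer objective: the adapted weights evaluated on the query set.
        meta_loss = meta_loss + F.mse_loss(F.linear(xq, w_adapted, b_adapted), yq)
    # backward() differentiates through the inner step, giving the second-order gradient.
    meta_loss.backward()
    meta_opt.step()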

To implement MAML, you should use the nn.functional interface. Then you can pass in the weights manually.

Hi Simon! Thanks for your suggestion. Can you give a short example of how to use the nn.functional interface to change the parameters w.r.t. which the gradients are computed during the backward pass?

As long as you have a way to map (input, weights) to the results, it will be fine. Depending on your needs, the following may or may not be the best approach.

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self, num_classes=10, use_dropout=True):
        super(LeNet, self).__init__()
        self.use_dropout = use_dropout
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def get_weights(self):
        return (self.conv1.weight, self.conv1.bias,
                self.conv2.weight, self.conv2.bias,
                self.fc1.weight, self.fc1.bias,
                self.fc2.weight, self.fc2.bias)

    def forward_with_weights(self, x,
                             conv1_w, conv1_b,
                             conv2_w, conv2_b,
                             fc1_w, fc1_b,
                             fc2_w, fc2_b):
        x = F.conv2d(x, conv1_w, conv1_b)
        x = F.relu(F.max_pool2d(x, 2))
        x = F.conv2d(x, conv2_w, conv2_b)
        if self.use_dropout:
            x = F.dropout2d(x, training=self.training)
        x = F.relu(F.max_pool2d(x, 2))
        x = x.view(-1, 320)
        x = F.linear(x, fc1_w, fc1_b)
        x = F.relu(x)
        if self.use_dropout:
            x = F.dropout(x, training=self.training)
        return F.linear(x, fc2_w, fc2_b)

    def forward(self, x):
        return self.forward_with_weights(x, *self.get_weights())
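
A rough sketch of how this could plug into the MAML-style update (the data, loss, and hyper-parameters below are placeholders): take gradients of the support loss w.r.t. the current weights with create_graph=True, build the adapted weights, and run forward_with_weights with them so the query loss can be backpropagated through the inner step.

import torch
import torch.nn.functional as F

model = LeNet()
inner_lr = 0.01

# Placeholder MNIST-sized batches for one task.
x_support = torch.randn(4, 1, 28, 28)
y_support = torch.randint(0, 10, (4,))
x_query = torch.randn(4, 1, 28, 28)
y_query = torch.randint(0, 10, (4,))

weights = model.get_weights()
support_loss = F.cross_entropy(model.forward_with_weights(x_support, *weights), y_support)
grads = torch.autograd.grad(support_loss, weights, create_graph=True)
adapted = [w - inner_lr * g for w, g in zip(weights, grads)]

# The query loss depends on the adapted weights, which in turn depend on the
# original parameters, so backward() produces the second-order MAML gradient.
query_loss = F.cross_entropy(model.forward_with_weights(x_query, *adapted), y_query)
query_loss.backward()

Keeping the weights as explicit function arguments is what keeps the inner update differentiable; copying new values into the parameters (e.g. via load_state_dict()) would not connect the adapted weights to the originals in the graph.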