Issue using the internal ._parameters attribute

I’m trying to access model parameters using the internal ._parameters attribute. When I define the model as below, I get the parameters without any issue:

from torch import nn

model = nn.Linear(10, 10)
print(model._parameters)

However, when I use ._parameters to get the parameters of a model defined as a class, I get an empty OrderedDict():

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = MyModel()
print(model._parameters)

Is there a solution to this using ._parameters?

NOTE: I understand that using internal attributes is frowned upon.
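A likely explanation, sketched below: `_parameters` only holds the parameters registered directly on that module, so for `MyModel` the `nn.Linear` weights live in `model._modules['fc']._parameters`, not in `model._parameters`. The helper `collect_parameters` is hypothetical; it walks `_modules` recursively to rebuild roughly what `named_parameters()` yields (unlike `named_parameters()`, it does not de-duplicate shared parameters).

```python
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)


def collect_parameters(module, prefix=""):
    """Hypothetical helper: gather each module's own _parameters, recursing through _modules."""
    found = {}
    for name, p in module._parameters.items():
        if p is not None:                      # e.g. Linear(bias=False) registers None
            found[prefix + name] = p
    for child_name, child in module._modules.items():
        if child is not None:
            found.update(collect_parameters(child, prefix + child_name + "."))
    return found


model = MyModel()
print(model._parameters)                       # empty: MyModel registers no Parameters itself
print(list(model._modules["fc"]._parameters))  # the Linear's own names: weight, bias
print(list(collect_parameters(model)))         # fully qualified names, e.g. fc.weight
```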

Hi,

Could you give more details on why you use this instead of model.parameters()?

A number of implementations that differentiate through an optimization algorithm use this to keep the updated weights differentiable. Here is a meta-learning package that uses this approach (see L. 273), and here is another user who uses it to solve a similar problem.

I think a nicer solution here is to just del/set the attribute so that you don’t get into trouble:

module._parameters[param_key] = memo[p]

# Can become
delattr(module, param_key)
setattr(module, param_key, memo[p])
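A minimal sketch of the delattr/setattr swap in a differentiable setting, assuming the goal is one inner SGD step whose step size `lr` (a hypothetical name) should receive a gradient. After the swap, `weight` is an ordinary tensor attribute rather than a registered Parameter, yet `Linear.forward` still picks it up and the outer loss backpropagates through the inner update.

```python
import torch
from torch import nn

torch.manual_seed(0)
lin = nn.Linear(3, 2)
x = torch.randn(4, 3)
lr = torch.tensor(0.1, requires_grad=True)   # hypothetical learnable step size

# One inner SGD step, kept differentiable with create_graph=True
w0 = lin.weight
inner_loss = lin(x).pow(2).sum()
(g,) = torch.autograd.grad(inner_loss, w0, create_graph=True)
new_weight = w0 - lr * g                     # non-leaf tensor, still attached to the graph

# Swap the registered Parameter for the plain tensor instead of writing into _parameters
delattr(lin, "weight")                       # removes the entry from lin._parameters
setattr(lin, "weight", new_weight)           # stored as a plain attribute, not a Parameter

outer_loss = lin(x).pow(2).sum()             # Linear.forward now reads the swapped tensor
outer_loss.backward()
print("weight" in lin._parameters)           # False
print(lr.grad)                               # gradient reached lr through the inner update
```

One consequence to keep in mind: the swapped tensor no longer appears in `lin.parameters()` or `lin.state_dict()`, so anything that enumerates registered parameters (an optimizer, for example) will not see it.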

Taking a step back, we have a new experimental “stateless” version of Module that might help you here (pytorch/_stateless.py at 8532061bce8da8b5fe9ecce1067ade16793a7ee3 · pytorch/pytorch · GitHub).
In particular, functional_call(model, temporary_params, input) lets you call your Module with a set of params without having to register them as parameters.
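A short sketch of `functional_call`, assuming PyTorch 2.x, where the experimental `_stateless.functional_call` linked above is available publicly as `torch.func.functional_call`. The doubled weights in `temporary_params` are only an illustration; the point is that the call uses them for this one forward pass while the module keeps its own registered Parameters.

```python
import torch
from torch import nn
from torch.func import functional_call  # assumes PyTorch 2.x

torch.manual_seed(0)
model = nn.Linear(10, 10)
x = torch.randn(4, 10)

# Plain tensors (not Parameters), keyed by the names named_parameters() uses
temporary_params = {name: p.detach() * 2 for name, p in model.named_parameters()}

out = functional_call(model, temporary_params, (x,))

# The same forward pass computed by hand with the doubled weights
expected = x @ (2 * model.weight).T + 2 * model.bias
print(torch.allclose(out, expected, atol=1e-6))
print(isinstance(model.weight, nn.Parameter))  # the module still holds its own Parameter
```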

So, re the first part, when I run:

delattr(model._modules[key], param_key)
setattr(model._modules[key], param_key, updated)

print(model._modules[key]._parameters[param_key])

I get the error:

print(model._modules[key]._parameters[param_key])

KeyError: 'weight'

where param_key is 'weight'. Moreover, when I loop over the parameters using

for param_key in model._modules[key]._parameters:
    # module._modules[key]._parameters[param_key] = memo[p]  # old

    delattr(module, param_key)
    setattr(module, param_key, memo[p])

upon the second iteration I get the error:

for param_key in model._modules[key]._parameters:

RuntimeError: OrderedDict mutated during iteration

Re the second part, I tried using _stateless with the snippet you provided here, and I’m getting the following error:

for name, tensor in parameters_and_buffers.items():

AttributeError: 'generator' object has no attribute 'items'

when I

print(params)

I get

<generator object Module.named_parameters at 0x7f98f8dc1450>

My model's __init__ includes:

self.fc1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.params_list = nn.ParameterList([self.fc1.weight, self.fc1.bias])

This is expected. If you want plain Tensors instead of Parameters, then they won’t be registered as parameters anymore (and so won’t be in _parameters).
But accessing the field on your Module will work as you expect and you will be able to differentiate through it.
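For the loop, a sketch assuming the goal is to replace every parameter of one submodule: iterating over `list(...)` snapshots the names first, which avoids the `OrderedDict mutated during iteration` error, since `delattr` removes entries from `_parameters` while the loop is running. `updated = p * 0.5` is a hypothetical stand-in for whatever update you compute. It also shows the behaviour described above: the names disappear from `_parameters` (hence the KeyError) but remain reachable as attributes.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4))
key = "0"                                   # nn.Sequential names its children "0", "1", ...
sub = model._modules[key]

originals = {}
# Iterate over a snapshot of the names: delattr() deletes entries from
# sub._parameters, and mutating a dict while iterating over it raises RuntimeError.
for param_key in list(sub._parameters):
    p = sub._parameters[param_key]
    if p is None:                           # e.g. a Linear built with bias=False
        continue
    originals[param_key] = p
    updated = p * 0.5                       # hypothetical update; result is a plain Tensor
    delattr(sub, param_key)
    setattr(sub, param_key, updated)

print(len(sub._parameters))                 # 0: nothing is registered as a Parameter any more
print(type(sub.weight).__name__)            # Tensor: still reachable as an ordinary attribute
out = model(torch.randn(2, 4))
out.sum().backward()                        # gradients flow back to the original Parameters
```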

AttributeError: 'generator' object has no attribute 'items'

Oh, that’s an oversight on our end.
Could you please open an issue on GitHub asking to add support for generator objects in that API?
As a workaround, you can use dict(model.named_parameters()) to make an actual dict from the generator.

Sure, will do. Thanks!

The request was shot down.

Let’s pursue that conversation on the issue.
Hopefully you are unblocked for now by adding a dict() call, right?

That is correct. Thanks!

I wrote a short snippet to test '_stateless'. In my example, I update the weights of a network in an inner optimization loop and learn the learning rate of those weight updates in an outer optimization loop (meta-optimization). I’m getting the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 10]], which is output 0 of AsStridedBackward0, is at version 12; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

My code snippet is:

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

from torch.nn.utils import _stateless


class MyDataset(Dataset):
    def __init__(self, N):
        self.N = N
        self.x = torch.rand(self.N, 10)
        self.y = torch.randint(0, 3, (self.N,))

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 3)

        self.relu = nn.ReLU()

        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))

    def forward(self, x):
        y = self.relu(self.fc1(x))
        return self.fc2(y)

epochs = 20
N = 100
dataset = DataLoader(dataset=MyDataset(N), batch_size=10)
model = MyModel()
loss_func = nn.CrossEntropyLoss()

optim = optim.Adam([model.alpha], lr=1e-3)

params = dict(model.named_parameters())
for i in range(epochs):
    model.train()
    train_loss = 0
    for batch_idx, (x, y) in enumerate(dataset):
        logits = _stateless.functional_call(model, params, x)             # predict
        loss_inner = loss_func(logits, y)                                 # loss
        optim.zero_grad()                                                 # reset grad
        loss_inner.backward(create_graph=True, inputs=params.values())    # compute grad
        train_loss += loss_inner.item()                                   # store loss
        for k, p in params.items():
            if k not in ('alpha', 'beta'):        # compare strings by value, not identity
                p.update = - model.alpha * p.grad
                params[k] = p + p.update                      # update weight

    print('Train Epoch: {}\tLoss: {:.6f}'.format(i, train_loss / N))
    logits = _stateless.functional_call(model, params, x)                 # predict
    loss_meta = loss_func(logits, y)
    loss_meta.backward()
    optim.step()                                  # step the optimizer, not the loss tensor

From the error message, I understand that the issue comes from the update of the second layer's weights ([3, 10] is the shape of fc2.weight), which points to an error in my inner-loop optimization. Do you have any suggestions? Thanks!

PS: I don't think this is directly related to the feature request, but I can move my question to the GitHub issue if you prefer to discuss it there. Thank you!

This is indeed unrelated.

If you enable anomaly mode, you will see that the problem is that some of the params are saved for backward but modified inplace. The fix is to make sure they are not:

for i in range(epochs):
    model.train()
    train_loss = 0
    params = dict(model.named_parameters()) # !! Extract these for each inner loop
    for batch_idx, (x, y) in enumerate(dataset):
        params = {k: v.clone() for k,v in params.items()} # !! Make sure the modified params are not the ones you use
        logits = _stateless.functional_call(model, params, x)             # predict
        # ... rest of your code
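For reference, a minimal end-to-end sketch of the same inner/outer loop, assuming PyTorch 2.x (`torch.func.functional_call` as the public counterpart of `_stateless.functional_call`). It swaps in `torch.autograd.grad` for `backward(create_graph=True, inputs=...)`: the inner gradients come back as new tensors instead of being accumulated into `.grad`, and each update builds fresh tensors out of place, so nothing autograd has saved is modified in place. The toy data, the epoch count and the positive initial value of `alpha` are assumptions for illustration; like the original snippet, the meta-loss reuses the last training batch, whereas a real meta-learning setup would typically evaluate it on held-out data.

```python
import torch
from torch import nn
from torch.func import functional_call  # assumes PyTorch 2.x

torch.manual_seed(0)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 3)
        self.relu = nn.ReLU()
        # Learnable inner-loop step size (meta-parameter); positive start is an assumption
        self.alpha = nn.Parameter(torch.tensor(0.01))

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = MyModel()
loss_func = nn.CrossEntropyLoss()
meta_opt = torch.optim.Adam([model.alpha], lr=1e-3)

# Toy data split into 10 batches of 10 (stands in for the DataLoader)
x_all = torch.rand(100, 10)
y_all = torch.randint(0, 3, (100,))
batches = list(zip(x_all.split(10), y_all.split(10)))

alpha_start = model.alpha.item()
for epoch in range(3):
    # Start each epoch from the registered weights, leaving alpha out of the inner update
    params = {k: v for k, v in model.named_parameters() if k != "alpha"}
    train_loss = 0.0
    for xb, yb in batches:
        loss_inner = loss_func(functional_call(model, params, (xb,)), yb)
        train_loss += loss_inner.item()
        # New gradient tensors, kept in the graph so alpha can be meta-learned
        grads = torch.autograd.grad(loss_inner, list(params.values()), create_graph=True)
        # Out-of-place SGD step: builds fresh tensors, never mutates saved ones
        params = {k: p - model.alpha * g for (k, p), g in zip(params.items(), grads)}

    # Outer (meta) step on alpha, backpropagating through all inner updates
    loss_meta = loss_func(functional_call(model, params, (xb,)), yb)
    meta_opt.zero_grad()
    loss_meta.backward()
    meta_opt.step()
    print(f"epoch {epoch}: inner loss {train_loss / len(batches):.4f}, alpha {model.alpha.item():.5f}")
```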

Thanks for the solution! Just for my own understanding, what exactly is params = {k: v.clone() for k, v in params.items()} doing?

It is cloning the Tensors inside params. That means that the parameters that get used in the forward have different memory than the ones that got modified inplace by the optimizer.
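A toy reproduction of that explanation (the function `inner_step` is hypothetical): the multiplication saves its input for backward, and an in-place update of that same tensor under `torch.no_grad()`, as an optimizer step would do, bumps its version counter so backward refuses to run. Using a clone in the forward pass leaves the saved tensor untouched.

```python
import torch


def inner_step(clone_first):
    """Hypothetical toy: returns True if backward succeeds, False otherwise."""
    w = torch.randn(3, requires_grad=True)
    p = w * 1.0                               # non-leaf tensor standing in for a parameter
    used = p.clone() if clone_first else p    # the tensor the "forward" actually consumes
    loss = (used * used).sum()                # mul saves `used` for the backward pass
    with torch.no_grad():
        p.add_(1.0)                           # in-place update, like an optimizer step
    try:
        loss.backward()
        return True
    except RuntimeError as err:
        print("backward failed:", str(err).splitlines()[0])
        return False


print(inner_step(clone_first=True))   # True: the clone has its own memory and version counter
print(inner_step(clone_first=False))  # False: the saved tensor was bumped to a newer version
```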


I was trying to add a Hebbian learning term (under which the weight update is proportional to the product of pre- and post-synaptic activations) to the inner optimization. To that end, I modified the code as follows:
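Reading the `Optim` function in the code below, the per-layer update it applies can be written as follows, assuming batch size 1, with $a^{(l)}$ the row vector of a layer's input activations, $a^{(l+1)}$ its output activations (after the ReLU for `fc1`, and the softmax of the logits for `fc2`), and $W^{(l)} \in \mathbb{R}^{n_{l+1} \times n_l}$ in `nn.Linear`'s [out_features, in_features] layout:

```latex
\Delta W^{(l)} = -\alpha \, \nabla_{W^{(l)}} \mathcal{L} + \beta \, \big(a^{(l+1)}\big)^{\top} a^{(l)},
\qquad
\Delta b^{(l)} = -\alpha \, \nabla_{b^{(l)}} \mathcal{L} + \beta \, a^{(l+1)}
```

The $\alpha$ term is the plain gradient-descent step, and the $\beta$ term is the Hebbian (post times pre) outer product, which has the same shape as $W^{(l)}$.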

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

from torch.nn.utils import _stateless


class MyDataset(Dataset):
    def __init__(self, N):
        self.N = N
        self.x = torch.rand(self.N, 10)
        self.y = torch.randint(0, 3, (self.N,))

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 3)

        self.relu = nn.ReLU()

        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))

    def forward(self, x):
        y = self.relu(self.fc1(x))
        return (x, y), self.fc2(y)

def Optim(params, alpha, beta, y, logits):

    # -- add network output to activations
    softmax = nn.Softmax(dim=1)
    activations = []
    for item in y:
        activations.append(item)
    activations.append(softmax(logits))

    i = 0
    for k, p in params.items():
        print(k[4:])
        if k[4:] == 'weight':
            p.update = - alpha * p.grad + beta * torch.matmul(activations[i+1].T, activations[i])
            params[k] = p + p.update   # update weight
        elif k[4:] == 'bias':
            p.update = - alpha * p.grad + beta * activations[i + 1].squeeze(0)
            params[k] = p + p.update   # update weight
            i += 1

epochs = 20
N = 10
dataset = DataLoader(dataset=MyDataset(N), batch_size=1)
model = MyModel()
loss_func = nn.CrossEntropyLoss()
optim = optim.Adam([model.alpha], lr=1e-3)
torch.autograd.set_detect_anomaly(True)

for i in range(epochs):
    model.train()
    train_loss = 0
    params = dict(model.named_parameters())
    for batch_idx, (x, y) in enumerate(dataset):
        params = {k: v.clone() for k, v in params.items()}
        activations, logits = _stateless.functional_call(model, params, x)             # predict
        loss_inner = loss_func(logits, y)                                 # loss
        loss_inner.backward(create_graph=True, inputs=params.values())    # compute grad
        train_loss += loss_inner.item()                                   # store loss
        Optim(params, model.alpha, model.beta, activations, logits)

    print('Train Epoch: {}\tLoss: {:.6f}'.format(i, train_loss / N))
    activations, logits = _stateless.functional_call(model, params, x)                 # predict
    loss_meta = loss_func(logits, y)
    loss_meta.backward()
    optim.step()

This gives me the error

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

when I change the line

    Optim(params, model.alpha, model.beta, activations, logits)

to

    Optim(params, model.alpha, model.beta.clone(), activations, logits)

the problem goes away. Since model.alpha does not need cloning, I have two questions:

  1. Is using model.beta.clone() the correct solution?
  2. Why, unlike model.beta, is cloning not needed when passing model.alpha to Optim?