cache = self.state_dict() gets overwritten

I am experimenting with randomly removing individual weights after the model has been trained. To restore the original state of the network after these experiments, I cache the state dict beforehand. However, restoring it does not work as expected. Can somebody point out my mistake?

[...]
cache = self.state_dict()
for layer in self.layers.parameters():
    print(layer.data[index])  # tensor(0.0034)

    layer.data[index] = 0
    print(layer.data[index])  # tensor(0.)

    [...] #evaluate the model using the disabled weight

    self.load_state_dict(cache)  # try to restore the original weights
    print(layer.data[index])  # tensor(0.), although tensor(0.0034) was expected

Try cache = copy.deepcopy(self.state_dict()). Currently, cache holds tensors that share their storage with your parameters (state_dict() does not copy the data), so every in-place change to a parameter also shows up in cache.


Thanks! So this seems to be a general misunderstanding of mutable data types in Python (e.g., as explained on StackOverflow).
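The same effect can be reproduced in plain Python. This is only an analogy, assuming lists stand in for tensor storage: state_dict() behaves much like a shallow copy, i.e. a new dict whose values still point at the same mutable data.

```python
import copy

params = {"weight": [0.5, -0.2, 0.1]}  # hypothetical stand-in for a model's parameters

shallow = dict(params)        # new dict, but the same list object inside
deep = copy.deepcopy(params)  # new dict and a new list

params["weight"][0] = 0.0     # in-place change, like layer.data[index] = 0
print(shallow["weight"])      # [0.0, -0.2, 0.1]
print(deep["weight"])         # [0.5, -0.2, 0.1]

params["weight"] = [9.9, 9.9, 9.9]  # rebinding, like layer.data = new_tensor
print(shallow["weight"])            # [0.0, -0.2, 0.1]: still the old list
```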

What I still don’t understand is why the following does not change the cache:

cache = self.state_dict()
weights = [value for (key, value) in cache.items() if '.weight' in key]
weight_probs = [torch.full_like(w, 0.5) for w in self.layers.parameters()]

with torch.no_grad():
    current = [torch.bernoulli(p) for p in weight_probs]
    for i, layer in enumerate(self.layers.parameters()):
        layer.data = weights[i] * current[i].float()

self.load_state_dict(cache)  # this works, cache seems to be unaltered, why?

This snippet does not change anything in cache.
It only reads the cached weights in layer.data = weights[i] * current[i].float(), which creates a new tensor and assigns it to layer.data.
If you changed the tensors in weights in place, the state dict would reflect this change as well:

import torch.nn as nn

model = nn.Linear(1, 1)
state_dict = model.state_dict()
w = [value for (key, value) in state_dict.items()]
w[-1][0] = -10000.  # in-place write into the bias tensor
print(w)
> [tensor([[-0.2974]]), tensor([-10000.])]
print(state_dict)
> OrderedDict([('weight', tensor([[-0.2974]])), ('bias', tensor([-10000.]))])
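A runnable sketch of the second snippet, assuming an nn.Sequential of two bias-free nn.Linear layers as a hypothetical stand-in for self.layers (without biases, the '.weight' filter and parameters() line up index by index). Because layer.data is rebound to the result of an out-of-place product, the cached tensors keep their original values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for self.layers; bias=False so every parameter is a weight.
layers = nn.Sequential(nn.Linear(3, 3, bias=False), nn.Linear(3, 2, bias=False))

cache = layers.state_dict()
weights = [value for key, value in cache.items() if ".weight" in key]
originals = [w.clone() for w in weights]  # independent copies, only used for checking

with torch.no_grad():
    masks = [torch.bernoulli(torch.full_like(w, 0.5)) for w in layers.parameters()]
    for layer, w, mask in zip(layers.parameters(), weights, masks):
        layer.data = w * mask  # new tensor; the cached tensor is left untouched

# The cached tensors still hold the original values ...
print(all(torch.equal(w, o) for w, o in zip(weights, originals)))  # True

layers.load_state_dict(cache)  # ... so restoring works
print(all(torch.equal(p, o) for p, o in zip(layers.parameters(), originals)))  # True
```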

I am sorry, but I cannot follow:
In both cases I iterate through model.layers.parameters() and change the weights in this loop. In the first case, I do not directly change anything in the cached model.state_dict() the way your last example does. What am I missing?

EDIT: Is the difference that in the first case I assign to an individual index (layer.data[index] = 0), whereas in the second case I replace the whole tensor with a new object (layer.data = torch.zeros_like(layer.data))?

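Sorry for the misunderstanding.
Yes, basically: if you update a tensor in place, the references still hold.
However, if you replace the underlying .data with a new tensor, you break the reference: the parameter now points to new storage, while the cached tensor keeps the old one.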
Have a look at this code:

import torch.nn as nn

model = nn.Linear(1, 1)
state_dict = model.state_dict()
weights = [value for key, value in state_dict.items()]

for i, (name, layer) in enumerate(model.named_parameters()):
    print(name)
    print('Before ', layer)
    print('Reference identical: ', layer.data.data_ptr()==weights[i].data_ptr())  # True
    layer.data = weights[i] * 2  # assigns a new tensor, breaking the shared storage
    print('After ', layer)
    print('Reference identical: ', layer.data.data_ptr()==weights[i].data_ptr())  # False
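For contrast, a sketch of the in-place case with the same data_ptr() check, again assuming a tiny nn.Linear. Writing into the existing tensor (here with mul_() under torch.no_grad()) keeps the storage shared, so the plain state_dict() snapshot sees the change:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(1, 1)
state_dict = model.state_dict()
weights = [value for key, value in state_dict.items()]

results = []
for i, (name, param) in enumerate(model.named_parameters()):
    with torch.no_grad():
        param.mul_(2)  # in-place update: the parameter keeps its storage
    same_storage = param.data_ptr() == weights[i].data_ptr()
    snapshot_matches = torch.equal(param, weights[i])  # the snapshot saw the update
    results.append((name, same_storage, snapshot_matches))
    print(name, 'Reference identical:', same_storage, '| Snapshot matches:', snapshot_matches)
```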