I am experimenting with randomly removing individual weights after the model has been trained. To restore the original state of the network after these experiments, I cache the state dict beforehand. However, restoring does not work as expected. Can somebody point out my mistake?
[...]
cache = self.state_dict()
for layer in self.layers.parameters():
    print(layer.data[index]) # tensor(0.0034)
    layer.data[index] = 0
    print(layer.data[index]) # tensor(0.)
[...] #evaluate the model using the disabled weight
self.load_state_dict(cache)
print(layer.data[index]) # tensor(0.)
Thanks! So this seems to be a general misunderstanding of mutable data types in Python (e.g. as explained on StackOverflow).
What I still don't understand is why the following does not change the state of the cache:
cache = self.state_dict()
weights = [value for (key,value) in cache.items() if '.weight' in key]
weight_probs = [torch.full_like(w, 0.5) for w in self.layers.parameters()]
with torch.no_grad():
    current = [torch.bernoulli(p) for p in weight_probs]
    for i, layer in enumerate(self.layers.parameters()):
        layer.data = weights[i] * current[i].float()
self.load_state_dict(cache) # this works, cache seems to be unaltered, why?
The code snippet does not change anything in cache.
It only reads the cached weights: weights[i] * ... creates a new tensor, which is then assigned to layer.data.
If you changed something in weights in place, the state dict would also reflect this change:
model = nn.Linear(1, 1)
state_dict = model.state_dict()
w = [value for (key, value) in state_dict.items()]
w[-1][0] = -10000.
print(w)
> [tensor([[-0.2974]]), tensor([-10000.])]
print(state_dict)
> OrderedDict([('weight', tensor([[-0.2974]])), ('bias', tensor([-10000.]))])
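For contrast, here is a minimal sketch of why the multiplication in your snippet leaves the cache alone. It is a toy example on a fresh nn.Linear, not your model, and the names cached, mask and product are made up for illustration. The * operator is out of place and allocates a new tensor, whereas mul_ writes into the storage that the state dict references:

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
state_dict = model.state_dict()
cached = state_dict['weight']      # shares storage with model.weight
before = cached.clone()            # independent copy, kept for comparison

mask = torch.zeros_like(cached)    # hypothetical mask that drops every weight

# Out-of-place multiplication: the result lives in newly allocated memory
product = cached * mask
shares_storage_out_of_place = product.data_ptr() == cached.data_ptr()
cache_unchanged = torch.equal(cached, before)

# In-place multiplication: writes into the storage the state dict references
with torch.no_grad():
    cached.mul_(mask)
cache_zeroed = torch.equal(state_dict['weight'], torch.zeros_like(before))
model_zeroed = torch.equal(model.weight.data, torch.zeros_like(before))

print(shares_storage_out_of_place, cache_unchanged, cache_zeroed, model_zeroed)
```

The in-place version changes both the state dict entry and the model parameter, because all three names point at the same memory.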
I am sorry, but I cannot follow:
In both cases I iterate through model.layers.parameters() and change the weights in this loop. I do not directly change anything in the cached model.state_dict() in the first case as you imply with your last example. What am I missing?
EDIT: Is the difference in accessing an individual index (layer.data[index] = 0) compared to updating the whole layer with a new object (layer.data = torch.zeros_like(layer.data))?
Sorry for the misunderstanding.
Yes, basically if you update in place, the references still hold.
However, if you assign a new tensor to the underlying .data, you break the reference.
Have a look at this code:
import torch.nn as nn

model = nn.Linear(1, 1)
state_dict = model.state_dict()
weights = [value for key, value in state_dict.items()]
for i, (name, layer) in enumerate(model.named_parameters()):
    print(name)
    print('Before ', layer)
    # The state dict entry and the parameter share the same memory here
    print('Reference identical: ', layer.data.data_ptr()==weights[i].data_ptr())
    # Out-of-place multiplication creates a new tensor; assigning it rebinds .data
    layer.data = weights[i] * 2
    print('After ', layer)
    print('Reference identical: ', layer.data.data_ptr()==weights[i].data_ptr())
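To tie this back to the original question: one way to make the cache survive in-place edits is to give it its own storage, for example with copy.deepcopy. This is only a sketch under my own assumptions (a plain nn.Linear stands in for your model, and the zeroed index is arbitrary), not necessarily how you want to structure your experiments:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(3, 2)            # stand-in for the trained model

# deepcopy gives the cache its own storage, so later in-place edits
# to the parameters cannot leak into it
cache = copy.deepcopy(model.state_dict())
original = model.weight.detach().clone()

with torch.no_grad():
    model.weight[0, 0] = 0.0       # in-place "removal" of a single weight
weight_was_zeroed = model.weight[0, 0].item() == 0.0
cache_untouched = torch.equal(cache['weight'], original)

# ... evaluate the model with the disabled weight here ...

model.load_state_dict(cache)       # copies the cached values back into the parameters
restored = torch.equal(model.weight.detach(), original)

print(weight_was_zeroed, cache_untouched, restored)
```

Cloning each tensor of the state dict individually should work just as well; deepcopy is simply the shortest way to write it.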