You could modify the `state_dict` and reload it afterwards:
```python
state_dict = model.state_dict()
state_dict['classifier.weight'] = torch.randn(10, 10)
model.load_state_dict(state_dict)
```
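For context, here is a minimal, self-contained sketch of this approach; the model with a `classifier` submodule is a hypothetical stand-in, since your actual architecture isn't shown:

```python
import torch
import torch.nn as nn

# hypothetical model with a 'classifier' submodule, as a stand-in
model = nn.Sequential()
model.add_module('features', nn.Linear(20, 10))
model.add_module('classifier', nn.Linear(10, 10))

state_dict = model.state_dict()
new_weight = torch.randn(10, 10)           # must match the original shape
state_dict['classifier.weight'] = new_weight
model.load_state_dict(state_dict)          # copies the new values into the module
```

Note that `load_state_dict` copies the values into the existing parameters, so the shapes have to match.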
Alternatively, if you are looping over `named_parameters`, make sure to manipulate the parameters in-place and inside a `torch.no_grad()` guard:
```python
with torch.no_grad():
    for name, param in model.named_parameters():
        if 'classifier.weight' in name:
            param.copy_(torch.randn(10, 10))
```
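One nice property of this approach (shown here on a small hypothetical model) is that `copy_` only overwrites the values while keeping the same `Parameter` object, so any existing references to it, e.g. inside an optimizer, stay valid:

```python
import torch
import torch.nn as nn

# hypothetical model with a 'classifier' submodule
model = nn.Sequential()
model.add_module('classifier', nn.Linear(10, 10))

old_param = model.classifier.weight   # keep a handle to the Parameter object
new_values = torch.randn(10, 10)

with torch.no_grad():
    for name, param in model.named_parameters():
        if 'classifier.weight' in name:
            param.copy_(new_values)

# the Parameter object is unchanged; only its values were overwritten
```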
I wouldn’t recommend using the `.data` attribute, as Autograd cannot track these operations and you might create silent bugs in your code.
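To illustrate the kind of silent bug this can cause, here is a minimal sketch: modifying a tensor through `.data` does not bump its version counter, so the backward pass uses the overwritten value without raising any error:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
x = torch.tensor(2.0)

y = w * x            # Autograd saves x to compute dy/dw in backward
x.data.fill_(100.0)  # in-place change via .data: no error, no version bump

y.backward()
# dy/dw should be the original x (2.0), but backward silently
# uses the overwritten value, so w.grad is now 100.0
print(w.grad)
```

An ordinary in-place op on `x` (e.g. `x.add_(...)`) would at least raise an error about a tensor needed for gradient computation being modified; going through `.data` removes that safety check.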