wonchulSon
(Wonchul Son)
August 18, 2020, 3:45pm
I add my_weight to the optimizer as shown below.
my_weight = nn.Parameter(torch.tensor(args.weight).cuda(), requires_grad=True)
self.optimizer.add_param_group({"params": [my_weight]})
I want to freeze my_weight on every third iteration and keep it trainable otherwise.
for batch_idx, (data, target) in enumerate(train_loader):
    iteration += 1
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    if iteration % 3 == 0:
        pass  # !!!!! I want to freeze my_weight here !!!!!
    else:
        pass  # !!!!! I want to unfreeze my_weight here !!!!!
    loss.backward()
    optimizer.step()
How can I do this?
Thank you.
I did something like this, though it is definitely not very elegant:
>>> import torch
>>> import torch.nn as nn
>>>
>>> device = 'cuda:0'
>>>
>>> w = nn.Parameter(torch.randn(1).to(device), requires_grad=True)
>>> b = nn.Parameter(torch.randn(1).to(device), requires_grad=True)
>>>
>>> opt = torch.optim.Adam([w], 0.1)
>>> opt.add_param_group({'params': [b]})
>>>
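>>> for i in range(9):
...     opt.zero_grad()
...     x = torch.randn(10, 1).to(device)
...     y = 3.14 * x + 1.85
...     # Elementwise product keeps y_pred at shape (10, 1); x @ w would give
...     # shape (10,) and broadcast against y to (10, 10) in the loss.
...     y_pred = x * w + b
...     loss = ((y_pred - y) ** 2).mean()
...     loss.backward()
...     opt.step()
...     if i % 3 == 0:
...         # Freeze b: drop its param group (index 1, since it was added
...         # second; pretty inelegant, I know) and its optimizer state.
...         b.requires_grad = False
...         del opt.param_groups[1]
...         del opt.state[b]
...     elif not b.requires_grad:
...         # Unfreeze b by registering it with the optimizer again.
...         b.requires_grad = True
...         opt.add_param_group({'params': [b]})
...     print('Iteration: {}, w: {}, b: {}'.format(i, w.item(), b.item()))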
...
Iteration: 0, w: 0.812088131904602, b: 1.4590768814086914
Iteration: 1, w: 0.7677664160728455, b: 1.4590768814086914
Iteration: 2, w: 0.7146307229995728, b: 1.5590769052505493
Iteration: 3, w: 0.6447014808654785, b: 1.611872673034668
Iteration: 4, w: 0.5945112705230713, b: 1.611872673034668
Iteration: 5, w: 0.5443614721298218, b: 1.51187264919281
Iteration: 6, w: 0.48854923248291016, b: 1.5435525178909302
Iteration: 7, w: 0.42736998200416565, b: 1.5435525178909302
Iteration: 8, w: 0.3671201169490814, b: 1.643552541732788
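A possibly simpler variant, offered only as a sketch: PyTorch optimizers skip any parameter whose .grad is None, so it may be enough to toggle requires_grad before the forward pass and clear gradients with set_to_none=True, leaving the param groups alone. Unlike the approach above, this keeps the Adam state for b across frozen iterations. The names freeze_b and history are made up for illustration, and the tensors stay on the CPU for brevity. Note that here the frozen iterations are 0, 3, 6 and the check happens before the step rather than after it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

w = nn.Parameter(torch.randn(1))
b = nn.Parameter(torch.randn(1))
opt = torch.optim.Adam([w, b], lr=0.1)

history = []  # (iteration, whether b changed during this step)
for i in range(9):
    freeze_b = (i % 3 == 0)
    # requires_grad must be set before the forward pass to take effect.
    b.requires_grad_(not freeze_b)
    # With set_to_none=True a frozen b ends up with grad None after backward,
    # and step() skips parameters whose grad is None.
    opt.zero_grad(set_to_none=True)

    x = torch.randn(10, 1)
    y = 3.14 * x + 1.85
    loss = ((x * w + b - y) ** 2).mean()
    loss.backward()

    b_before = b.item()
    opt.step()
    history.append((i, b.item() != b_before))

for i, changed in history:
    print('Iteration: {}, b updated: {}'.format(i, changed))
```

If the freeze decision has to be made after the forward pass, as in the original loop, setting my_weight.grad = None between loss.backward() and optimizer.step() should have the same effect on the update.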