How can I freeze a custom learnable weight?

I add my_weight to the optimizer like this:

my_weight = nn.Parameter(torch.tensor(args.weight).cuda(), requires_grad=True)
self.optimizer.add_param_group({"params": [my_weight]})

And I want to freeze it on every third iteration.

for batch_idx, (data, target) in enumerate(train_loader):
    iteration += 1
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

    if iteration % 3 == 0:
        pass  # I want to freeze my_weight here
    else:
        pass  # I want to unfreeze my_weight here

    loss.backward()
    optimizer.step()

How can I do this?

Thank you.

I did something like this, but it's definitely not very elegant…

>>> import torch
>>> import torch.nn as nn
>>> 
>>> device = 'cuda:0'
>>> 
>>> w = nn.Parameter(torch.randn(1).to(device), requires_grad=True)
>>> b = nn.Parameter(torch.randn(1).to(device), requires_grad=True)
>>> 
>>> opt = torch.optim.Adam([w], 0.1)
>>> opt.add_param_group({'params': [b]})
>>> 
>>> for i in range(9):
...     opt.zero_grad()
...     x = torch.randn(10, 1).to(device)
...     y = 3.14 * x + 1.85
...     y_pred = x @ w + b
...     loss = ((y_pred - y) ** 2).mean()
...     loss.backward()
...     opt.step()
...     if i % 3 == 0:
...             b.requires_grad = False
...             del opt.param_groups[1] # b is the second param group since I added it later; I know, pretty inelegant
...             del opt.state[b]
...     elif not b.requires_grad:
...             b.requires_grad = True
...             opt.add_param_group({'params': [b]})
...     print('Iteration: {}, w: {}, b: {}'.format(i, w.item(), b.item()))
... 
Iteration: 0, w: 0.812088131904602, b: 1.4590768814086914
Iteration: 1, w: 0.7677664160728455, b: 1.4590768814086914
Iteration: 2, w: 0.7146307229995728, b: 1.5590769052505493
Iteration: 3, w: 0.6447014808654785, b: 1.611872673034668
Iteration: 4, w: 0.5945112705230713, b: 1.611872673034668
Iteration: 5, w: 0.5443614721298218, b: 1.51187264919281
Iteration: 6, w: 0.48854923248291016, b: 1.5435525178909302
Iteration: 7, w: 0.42736998200416565, b: 1.5435525178909302
Iteration: 8, w: 0.3671201169490814, b: 1.643552541732788
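A simpler alternative (just a sketch, not tested against the exact code above): the built-in optimizers skip any parameter whose .grad is None, so you can "freeze" a parameter for a single step by clearing its gradient after backward() and before step(). No need to delete and re-add the param group. Adapted to a toy regression like the one above (running on CPU here for simplicity; swap in 'cuda:0' as in the original):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = 'cpu'  # use 'cuda:0' if a GPU is available

w = nn.Parameter(torch.randn(1, device=device))
b = nn.Parameter(torch.randn(1, device=device))
opt = torch.optim.Adam([w, b], 0.1)

b_vals = [b.item()]  # track b after every step (index 0 = initial value)
for i in range(9):
    opt.zero_grad()
    x = torch.randn(10, 1, device=device)
    y = 3.14 * x + 1.85
    y_pred = x * w + b
    loss = ((y_pred - y) ** 2).mean()
    loss.backward()
    if i % 3 == 0:
        # "freeze" b for this step: optimizers skip parameters whose
        # .grad is None, so b (and its Adam state) is left untouched
        b.grad = None
    opt.step()
    b_vals.append(b.item())
```

One thing to be aware of: this freezes b only for the step where the gradient is cleared, and it also avoids the momentum-driven drift you would get if you merely zeroed the gradient, since Adam's exp_avg buffer for b is not touched on skipped steps.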