Hi. I’m running into an issue when trying to use PyTorch optimizers with Apex AMP. My environment is:
- OS: Ubuntu 18.04.5
- Python: 3.8.5
- PyTorch: 1.7.1
- CUDA: 11.0
- Apex: 0.9.10.dev0
- Transformers: 4.3.3
You can reproduce my error with the following code:
from apex import amp
from transformers import AutoModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
model = AutoModel.from_pretrained('bert-base-cased')
model = model.cuda()
new_layer = ["extractor", "bilinear"]
# Parameters belonging to the new layers get their own learning rate (1e-4);
# everything else falls back to the optimizer default (1e-5).
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)]},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": 1e-4},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5, eps=1e-10)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=4000, num_training_steps=10000)
Commenting out the model, optimizer = amp.initialize(...) line lets the script run fine. With it in place, running the script produces the following traceback:
Traceback (most recent call last):
File "../train.py", line 33, in finetune
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
File "/home/User/anaconda3/lib/python3.8/site-packages/transformers/optimization.py", line 98, in get_linear_schedule_with_warmup
return LambdaLR(optimizer, lr_lambda, last_epoch)
File "/home/User/anaconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 205, in __init__
super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)
File "/home/User/anaconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 74, in __init__
self.optimizer.step = with_counter(self.optimizer.step)
File "/home/User/anaconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 56, in with_counter
instance_ref = weakref.ref(method.__self__)
AttributeError: 'function' object has no attribute '__self__'
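Judging from the last frame, LambdaLR assumes optimizer.step is a bound method: with_counter calls weakref.ref(method.__self__). A quick standalone check (with torch.nn.Linear as a stand-in for the BERT model, same versions as above) seems to confirm that amp.initialize replaces step with a plain function:

import torch
from apex import amp
from transformers.optimization import AdamW

model = torch.nn.Linear(4, 4).cuda()   # small stand-in for the BERT model
optimizer = AdamW(model.parameters(), lr=1e-5)
print(type(optimizer.step).__name__)        # 'method' -- bound method, has __self__
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
print(type(optimizer.step).__name__)        # 'function' -- plain function after patching
print(hasattr(optimizer.step, "__self__"))  # False, which matches the AttributeError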
What might the issue be, and how could I fix this? One idea I’m considering is sketched below; I’d appreciate any pointers on whether it’s actually safe.
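The idea is to construct the scheduler before calling amp.initialize, so that LambdaLR takes its weakref while optimizer.step is still a bound method. This avoids the crash at construction time in my reading of the traceback, but I haven’t verified whether it interacts badly with the step patching Apex does afterwards:

from apex import amp
from transformers import AutoModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

model = AutoModel.from_pretrained('bert-base-cased').cuda()
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-10)
# Build the scheduler first, while optimizer.step is still a bound method,
# so LambdaLR can take its weakref without crashing.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=4000, num_training_steps=10000)
# Only afterwards let Apex patch the model and optimizer.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

Thanks.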