I tried your solution like this:
def train_one_epoch(dataloader, model, scheduler, optimizer, scaler, epoch):
    torch.manual_seed(42)
    model.train()
    total_loss = 0
    pbar = tqdm(dataloader, desc=f"Train: Epoch {epoch + 1}", total=len(dataloader), mininterval=5)
    for img, target in pbar:
        img = img.to(device)
        if CFG.fp16:
            optimizer.zero_grad()
            with autocast(enabled=True):
                outputs = model(img)
                target = target.unsqueeze(1)
                target = target.to(float).to(device)
                loss = criterion(outputs, target)
            if np.isinf(loss.item()) or np.isnan(loss.item()):
                print('Bad loss, skipping the batch')
                del loss, outputs
                gc.collect()
                continue
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, norm_type=2.0)
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            # https://discuss.pytorch.org/t/userwarning-detected-call-of-lr-scheduler-step-before-optimizer-step/164814/2
            old_scaler = scaler.get_scale()
            scaler.update()
            new_scaler = scaler.get_scale()
            if new_scaler < old_scaler:
                if scheduler is not None:
                    scheduler.step()
            else:
                print("old_scaler", old_scaler, "new_scaler", new_scaler)
Then I get a UserWarning like this:
Train: Epoch 1: 0%| | 0/511 [00:00<?, ?it/s]/home/ansary/anaconda3/envs/mobassir/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:138: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
Train: Epoch 1: 0%| | 1/511 [00:08<50:04, 5.89s/it, loss=1.63, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 0%| | 1/511 [00:09<50:04, 5.89s/it, loss=2.77, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 0%| | 1/511 [00:10<50:04, 5.89s/it, loss=1.59, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 0%| | 1/511 [00:10<50:04, 5.89s/it, loss=2.04, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 2%| | 10/511 [00:11<08:18, 1.00it/s, loss=1.56, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 2%| | 10/511 [00:12<08:18, 1.00it/s, loss=2.1, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 2%| | 10/511 [00:12<08:18, 1.00it/s, loss=1.39, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 2%| | 10/511 [00:13<08:18, 1.00it/s, loss=2.01, lr=0.000225]
old_scaler 2048.0 new_scaler 2048.0
Train: Epoch 1: 2%| | 10/511 [00:14<08:18, 1.00it/s, loss=1.63, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 2%| | 10/511 [00:15<08:18, 1.00it/s, loss=1.6, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 2%| | 10/511 [00:15<08:18, 1.00it/s, loss=1.13, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:16<06:34, 1.25it/s, loss=1.15, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:17<06:34, 1.25it/s, loss=0.942, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:17<06:34, 1.25it/s, loss=0.945, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:18<06:34, 1.25it/s, loss=0.887, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:19<06:34, 1.25it/s, loss=0.543, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Train: Epoch 1: 4%| | 18/511 [00:19<06:34, 1.25it/s, loss=0.464, lr=0.000225]
old_scaler 1024.0 new_scaler 1024.0
Then I tried if new_scaler <= old_scaler: as well, but no luck; I get the same UserWarning over and over again while using a OneCycleLR scheduler stepped per batch with fp16 enabled. I was using PyTorch 2.0.
@ptrblck any guess?
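In case it helps, here is a small standalone snippet (my own repro attempt, not the training code above) that triggers the same warning for me by forcing GradScaler to skip optimizer.step() on the first iteration and then calling scheduler.step() anyway. It assumes a CUDA device, and the model/optimizer/scheduler values are just placeholders:

# Standalone repro sketch (assumes a CUDA device; all values are placeholders).
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, total_steps=10)
scaler = GradScaler()

x = torch.randn(8, 4, device="cuda")
with autocast(enabled=True):
    loss = (model(x) * float("inf")).mean()   # force an inf loss -> inf/NaN gradients

scaler.scale(loss).backward()
scaler.step(optimizer)    # GradScaler finds inf grads and skips optimizer.step()
scaler.update()           # scale is reduced because inf grads were found
scheduler.step()          # -> UserWarning: lr_scheduler.step() before optimizer.step()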