Using CosineAnnealingWarmRestarts in this setup causes improper behaviour. What could be wrong with the following training loop?
import torch
from tqdm import tqdm

def train_fn(loader, model, optimizer, loss_fn, scaler):
    # note: the scheduler is created fresh on every call to train_fn
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=1, T_mult=2, eta_min=5e-5,
    )
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)  # DEVICE is defined elsewhere in the script
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward pass under mixed precision
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        lr_scheduler.step()  # stepped once per batch, with no epoch argument
        scaler.update()

        # update the tqdm progress bar
        loop.set_postfix(loss=loss.item())
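For comparison, the example in the PyTorch documentation constructs the scheduler once, outside the per-epoch function, and steps it with a fractional epoch inside the batch loop. Below is a minimal runnable sketch of that pattern; the toy model, NUM_EPOCHS, and iters_per_epoch are placeholders made up for illustration, not part of my script:

import torch
from torch import nn, optim

# toy model and optimizer so the sketch is self-contained (placeholders)
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# constructed once, so its internal cycle state survives across epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=5e-5,
)

NUM_EPOCHS = 4         # placeholder
iters_per_epoch = 100  # would normally be len(loader)

for epoch in range(NUM_EPOCHS):
    for i in range(iters_per_epoch):
        # ... forward / backward would go here ...
        optimizer.step()  # stand-in for scaler.step(optimizer)
        # per the docs, step with a fractional epoch so the cosine
        # cycle advances smoothly within an epoch
        scheduler.step(epoch + i / iters_per_epoch)

The two differences from my loop are that the scheduler outlives a single epoch and that step() receives a fractional epoch argument rather than being called bare once per batch.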