AttributeError: 'MultiStepLR' object has no attribute 'param_groups'

Hello, when I add a learning rate scheduler, the error appeared. I guess the part should be fixed are in
save_checkpoint or load_checkpoint shown below. How should I fixed the error? Thanks.

Functions:

import torch
import config
import torch.nn as nn
import random
import os
import numpy as np
from torchvision.utils import save_image
from skimage.metrics import structural_similarity as ssim

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

Training script:

from model3 import AutoEncoder
import torch
import numpy as np
from dataset import SurDataset
from utils import save_checkpoint, load_checkpoint, save_some_examples, seed_everything, initialize_weights
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import ssim_for_img_tensor

loss_list = []
ssim_list = []
def train_fn(epoch, loader, val_loader, model, optimizer, scheduler, loss_fn, scaler):
    loop = tqdm(loader)

    for idx, (csv_, target) in enumerate(loop):
        csv_ = csv_.to(config.DEVICE)
        target = target.to(config.DEVICE)
        with torch.cuda.amp.autocast():
            model.eval()
            predict = model(csv_)
            loss = loss_fn(predict, target)
            model.train()
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.step(scheduler)
        scaler.update()
        ssim=ssim_for_img_tensor(target, predict, mode='sigmoid')
        loss_list.append(loss.item())
        ssim_list.append(ssim)
        loop.set_description(f"{epoch+1}/{config.NUM_EPOCHS}")
        loop.set_postfix(loss=loss.item(), batch_avg_ssim=ssim)
        if idx % 1 == 0:
            save_some_examples(model, val_loader, epoch, folder=config.EVALUATE_FOLDERNAME, mode='sigmoid')
    if config.SAVE_MODEL:
        save_checkpoint(model, optimizer, filename=config.CHECKPOINT)
    ##################################################  plot  ###################################################
    plt.figure()

    plt.subplot(211)
    plt.plot(np.arange(1, epoch+2), loss_list, label='MSE Loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Loss vs Epochs Trend")
    plt.legend(loc = "best", fontsize=10)
    plt.xticks(np.arange(1, config.NUM_EPOCHS+1))

    plt.subplot(212)
    plt.plot(np.arange(1, epoch+2), ssim_list, label='SSIM Score')
    plt.xlabel("Epochs")
    plt.ylabel("SSIM Score")
    plt.title("SSIM vs Epochs Trend")
    plt.legend(loc = "best", fontsize=10)
    plt.xticks(np.arange(1, config.NUM_EPOCHS+1))

    plt.subplots_adjust(left=0.125,
                bottom=0.1, 
                right=0.9, 
                top=1, 
                wspace=0.2, 
                hspace=0.35)

    plt.savefig(f"{config.MODELNAME}_training_process.png")
    ##################################################  plot  ###################################################


def main():
    seed_everything(42)
    print(config.DEVICE)
    model = AutoEncoder(in_channels=1, out_channels=1).to(config.DEVICE)
    initialize_weights(model)
    optimizer = optim.Adam(list(model.parameters()),
                        lr=config.LEARNING_RATE,
                        betas=(0.5, 0.999))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[7, 15, 23], gamma=0.1)
    mse = nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()
    dataset = SurDataset(csv_dir=config.CSV_DIR, img_dir=config.IMG_DIR, rescale=False)
    train_set, val_set = torch.utils.data.random_split(dataset, [12000, len(dataset) - 12000])
    loader = DataLoader(dataset=train_set, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=config.BATCH_SIZE)

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT, model, optimizer, config.LEARNING_RATE,
        )
    model.train()
    for epoch in range(config.NUM_EPOCHS):
        train_fn(epoch=epoch, loader=loader, val_loader=val_loader, model=model, optimizer=optimizer, scheduler=scheduler, loss_fn=mse, scaler=scaler)
        
if __name__=='__main__':
    main()

Errors:

Traceback (most recent call last):
  File "C:\Users\PML\.conda\envs\floren\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\PML\.conda\envs\floren\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\PML\.vscode\extensions\ms-python.python-2021.6.944021595\pythonFiles\lib\python\debugpy\__main__.py", line 45, in <module>
    cli.main()
  File "c:\Users\PML\.vscode\extensions\ms-python.python-2021.6.944021595\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 444, in main
    run()
  File "c:\Users\PML\.vscode\extensions\ms-python.python-2021.6.944021595\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "C:\Users\PML\.conda\envs\floren\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\Users\PML\.conda\envs\floren\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\Users\PML\.conda\envs\floren\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\PML\Documents\Florentino_space\cast_iron_preprocess\autoencoder20210615\train2.py", line 96, in <module>      
    main()
  File "c:\Users\PML\Documents\Florentino_space\cast_iron_preprocess\autoencoder20210615\train2.py", line 93, in main
    train_fn(epoch=epoch, loader=loader, val_loader=val_loader, model=model, optimizer=optimizer, scheduler=scheduler, loss_fn=mse, scaler=scaler)
  File "c:\Users\PML\Documents\Florentino_space\cast_iron_preprocess\autoencoder20210615\train2.py", line 30, in train_fn      
    scaler.step(scheduler)
  File "C:\Users\PML\.conda\envs\floren\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 316, in step
    self.unscale_(optimizer)
  File "C:\Users\PML\.conda\envs\floren\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 267, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "C:\Users\PML\.conda\envs\floren\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 194, in _unscale_grads_
    for group in optimizer.param_groups:
AttributeError: 'MultiStepLR' object has no attribute 'param_groups'

It seems you are using mixed-precision training via torch.cuda.amp and are calling scaler.step(scheduler) instead of scaler.step(optimizer). Could you change it and rerun the script?