Bug in Quantization example?

When I follow the tutorial at https://github.com/pytorch/vision/blob/39021408587eb252ebb842a54195d840d6d76095/references/classification/train_quantization.py and reload a pretrained QAT model, I get the following error:

Traceback (most recent call last):
  File "train_quantization.py", line 461, in <module>
    main(args)
  File "train_quantization.py", line 311, in main
    args.print_freq, priors)
  File "train_quantization.py", line 67, in train_one_epoch
    optimizer.step()
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 67, in wrapper
    return wrapped(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/sgd.py", line 106, in step
    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
RuntimeError: expected device cpu but got device cuda:0

I noticed that the error comes from reloading the optimizer state. If I create a new optimizer instead, training runs fine:

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    # optimizer.load_state_dict(checkpoint['optimizer'])  # disabled: restoring this state triggers the error
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    args.start_epoch = checkpoint['epoch'] + 1
optimizer = torch.optim.SGD(
    model.parameters(), lr=args.lr, momentum=args.momentum,
    weight_decay=args.weight_decay)
model.to(device)
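The traceback is consistent with SGD's momentum buffers sitting on the CPU while the parameters and gradients are on cuda:0. One plausible mechanism: Optimizer.load_state_dict casts each state tensor to the device its parameter has at load time, so restoring the checkpoint (read with map_location='cpu') while some parameters are still on the CPU, and moving the model afterwards, leaves those buffers behind. The sketch below is illustrative only. It assumes a toy nn.Linear in place of the QAT network, an in-memory io.BytesIO checkpoint in place of args.resume, and a hypothetical helper move_optimizer_state; it shows two ways to keep the saved momentum instead of discarding it.

```python
import io

import torch
import torch.nn as nn


def move_optimizer_state(optimizer, device):
    """Hypothetical helper: move every tensor in the optimizer state to `device`."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)


def make_model_and_optimizer():
    # Toy stand-in for the QAT model and the script's SGD settings.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Produce a checkpoint that contains momentum buffers, then reload it on the CPU,
# mirroring torch.load(args.resume, map_location='cpu') in the script.
model, optimizer = make_model_and_optimizer()
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

# Option 1: move the model to `device` *before* restoring the optimizer, so that
# load_state_dict casts each momentum buffer to its parameter's device.
model, optimizer = make_model_and_optimizer()
model.to(device)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

# Option 2: if the model can only be moved after loading, move the state as well.
model2, optimizer2 = make_model_and_optimizer()
model2.load_state_dict(checkpoint["model"])
optimizer2.load_state_dict(checkpoint["optimizer"])
model2.to(device)
move_optimizer_state(optimizer2, device)

# Both optimizers can now step with buffers and gradients on the same device.
for m, opt in ((model, optimizer), (model2, optimizer2)):
    opt.zero_grad()
    m(torch.randn(8, 4, device=device)).sum().backward()
    opt.step()
```

Creating a fresh optimizer, as in the workaround above, also avoids the error, but it starts training with empty momentum buffers rather than the ones saved in the checkpoint.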

I am using the newest PyTorch Docker image from Docker Hub.

Hi @xieydd, there have been two bug fixes for these issues, and both have landed in master. One fix updated device affinity when enabling distributed mode, and the other updated device affinity when converting a model to QAT. Would you mind trying the nightly build? Both fixes will also be part of the upcoming v1.6 release.

If getting the nightly is difficult, a quick fix is to insert model.to(device) after line 71 of the script.
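One reading of this quick fix is "call model.to(device) again once prepare_qat has inserted its fake-quantization modules, and only then build or restore the optimizer", so that any modules created during QAT preparation end up on the same device as the weights. A minimal sketch of that ordering follows. It assumes a toy nn.Sequential model instead of the torchvision quantizable network and the fbgemm backend; which statement line 71 corresponds to is also an assumption here.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy float model standing in for the torchvision quantizable network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # 8x8 input -> 6x6 feature map
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)
model.to(device)

# Swap in QAT modules and fake-quantization observers (requires training mode).
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# The quick fix: move again so modules created by prepare_qat share the device.
model.to(device)

# Only now build (or restore) the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

out = model(torch.randn(2, 3, 8, 8, device=device))
out.sum().backward()
optimizer.step()
```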

The error occurs even after I call model.to(device); let me try to make that clearer.

Was my description unclear?

Would you mind letting us know:

  • the PyTorch version you are using
  • the exact command you used to run the tutorial script?

That would help us reproduce the issue locally.
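For the version details, running python -m torch.utils.collect_env prints a full environment report. A shorter sketch that collects just the fields relevant to a device mismatch could look like this; the function name environment_report is hypothetical, and torchvision is reported as None if it is missing or fails to import.

```python
import platform

import torch


def environment_report():
    """Collect the versions that help reproduce a device/quantization issue."""
    try:
        import torchvision
        torchvision_version = torchvision.__version__
    except Exception:  # missing or incompatible torchvision install
        torchvision_version = None
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "torchvision": torchvision_version,
        "cuda_build": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


for name, value in environment_report().items():
    print(f"{name}: {value}")
```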