Applying Apex amp to DETR

I’m trying to apply apex.amp to the recent detection transformer (DETR) code (link).
What I’m not sure about is where to put amp.initialize.
Here are the lines from DETR’s main.py where the model and optimizer are declared (from line #121):

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

According to the apex ImageNet example, amp.initialize should be placed before the model is wrapped with DDP. I could move the DDP wrapping part to after the lr_scheduler, like below:

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model  

    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # Apex amp
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

But I wonder if I can instead pass model_without_ddp to amp.initialize at the end of the above code, without moving the DDP wrapping lines:

...
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    
    # Apex amp
    model, optimizer = amp.initialize(model_without_ddp, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

We recommend setting up the mixed precision training before wrapping the model into DDP.
Wouldn’t this work for you, or what is the reason you would like to set up DDP before amp?

Also, note that we recommend trying out the native amp implementation using the nightly binaries, as explained here.
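
The usage would look roughly like this. This is just a sketch loosely following the DETR loop names (samples, targets, criterion.weight_dict), not a drop-in replacement for engine.py:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler()

    for samples, targets in data_loader:
        optimizer.zero_grad()
        # run the forward pass and loss computation under autocast
        with autocast():
            outputs = model(samples)
            loss_dict = criterion(outputs, targets)
            losses = sum(loss_dict[k] * criterion.weight_dict[k]
                         for k in loss_dict if k in criterion.weight_dict)
        # scale the loss, then step/update via the GradScaler
        scaler.scale(losses).backward()
        # if gradients are clipped (args.clip_max_norm in DETR), unscale first:
        # scaler.unscale_(optimizer)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_max_norm)
        scaler.step(optimizer)
        scaler.update()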

Thanks for your help.
Not an important reason. I thought it might be possible to add amp without any changes to the original code, if the second option is valid.

I also considered using native amp in the nightly version.
But I was not sure whether it would be better to use the nightly version or wait until the stable version is released.
Maybe I should try native amp as well.

@ptrblck, while trying the 1st option with apex, I got the following error.

Traceback (most recent call last):
  File "main_apex.py", line 269, in <module>
    main(args)
  File "main_apex.py", line 219, in main
    args.clip_max_norm)
  File "/home/detr/engine_apex.py", line 37, in train_one_epoch
    loss_dict = criterion(outputs, targets)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/detr/models/detr.py", line 225, in forward
    indices = self.matcher(outputs_without_aux, targets)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/detr/models/matcher.py", line 71, in forward
    cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
  File "/usr/local/lib/python3.7/dist-packages/torch/functional.py", line 732, in cdist
    return _VF.cdist(x1, x2, p, None)
RuntimeError: "cdist_cuda" not implemented for 'Half'

It seems the torch.cdist function does not work with half precision.
Is there any workaround?
Can this issue be resolved in the nightly version?

You could decorate the method with @amp.float_function (maybe you would have to wrap it into another helper method), which should disable the casting.
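
Something along these lines might work. This is only a sketch: cdist_fp32 is an illustrative wrapper name, and as far as I know the decoration (or a call to amp.register_float_function) has to be in place before amp.initialize is called:

    from apex import amp
    import torch

    # apex should cast the inputs of this wrapper to FP32 before calling it,
    # which avoids the "cdist_cuda not implemented for 'Half'" error
    @amp.float_function
    def cdist_fp32(x1, x2, p=1.0):
        return torch.cdist(x1, x2, p=p)

    # In models/matcher.py the direct call
    #     cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
    # would then become
    #     cost_bbox = cdist_fp32(out_bbox, tgt_bbox, p=1)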

This might be the case. If you see that it’s not working in native amp, please let us know.

I’m working with native amp now and found that the above issue does not occur, because torch.cdist is included in the ops that autocast to float32.
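
For example, the matcher’s cdist call now runs in float32 under autocast without any changes (a quick sketch with random CUDA tensors):

    import torch

    out_bbox = torch.randn(100, 4, device='cuda')
    tgt_bbox = torch.randn(20, 4, device='cuda')

    # torch.cdist is on the "autocast to float32" list, so its inputs are cast
    # to FP32 automatically and the result is FP32 as well
    with torch.cuda.amp.autocast():
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
    print(cost_bbox.dtype)  # torch.float32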

However, I face another issue. The targets in the code are a list of dicts of tensors (line), and they are passed to the criterion.
It seems autocast cannot handle this type of input.
What is the best way to make it work with autocast?
Currently, I disable autocast at each loss and cast the outputs and targets to float, following this.
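
Roughly, the pattern looks like this (a sketch: the to_float32 helper is my own and the names just follow the DETR training loop; in the real code I apply the same idea inside each individual loss):

    import torch

    def to_float32(x):
        # recursively cast floating-point tensors inside dicts/lists to FP32
        if torch.is_tensor(x):
            return x.float() if x.is_floating_point() else x
        if isinstance(x, dict):
            return {k: to_float32(v) for k, v in x.items()}
        if isinstance(x, (list, tuple)):
            return type(x)(to_float32(v) for v in x)
        return x

    # forward pass under autocast, loss computation in full precision
    with torch.cuda.amp.autocast():
        outputs = model(samples)

    with torch.cuda.amp.autocast(enabled=False):
        loss_dict = criterion(to_float32(outputs), to_float32(targets))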

It seems the loss functions are mostly calculated in float32. Is that right?
I wonder if there is a float_function-like decorator in native amp (like the one you mentioned above) so that I can decorate the loss functions, such as this one.
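
What I have in mind is something like this hand-rolled decorator (not an official API, just a sketch that disables autocast and casts floating-point tensor arguments to FP32, similar in spirit to apex’s float_function):

    import functools
    import torch

    def force_fp32(fn):
        # cast top-level floating-point tensor arguments to FP32; nested
        # dicts/lists of tensors (like the DETR targets) would need extra handling
        def cast(x):
            return x.float() if torch.is_tensor(x) and x.is_floating_point() else x

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with torch.cuda.amp.autocast(enabled=False):
                args = [cast(a) for a in args]
                kwargs = {k: cast(v) for k, v in kwargs.items()}
                return fn(*args, **kwargs)
        return wrapper

    # usage sketch on a loss helper:
    # @force_fp32
    # def loss_boxes(outputs, targets, indices, num_boxes):
    #     ...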

Hi, did you figure it out? I’ve encountered the same issue when applying apex amp to DETR.
I’ve tried decorating SetCriterion.forward() with amp.float_function, but the errors remain.