I’m trying to apply apex.amp to the recent detection transformer (DETR) code (link).
What I’m not sure about is where to put amp.initialize.
Here are the lines from DETR’s main.py where the model and optimizer are declared (starting at line 121):
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
param_dicts = [
    {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
    {
        "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
        "lr": args.lr_backbone,
    },
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                              weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
According to the apex imagenet example, amp.initialize should be called before wrapping the model with DDP. I can move the DDP wrapping part to after the lr_scheduler, like below:
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
param_dicts = [
    {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
    {
        "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
        "lr": args.lr_backbone,
    },
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                              weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
# Apex amp
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level=args.opt_level,
                                  keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                  loss_scale=args.loss_scale
                                  )
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
But I wonder whether I can instead leave the DDP wrapping lines where they are and pass model_without_ddp to amp.initialize at the end of the original code:
...
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                              weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
# Apex amp
model, optimizer = amp.initialize(model_without_ddp, optimizer,
                                  opt_level=args.opt_level,
                                  keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                  loss_scale=args.loss_scale
                                  )
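One thing I noticed about that last variant is the assignment itself: it binds the return value to model, not to model_without_ddp. Below is a minimal pure-Python sketch of what that rebinding does. Module, Wrapper and fake_initialize are hypothetical stand-ins (not PyTorch or apex classes), and the sketch assumes that amp.initialize hands back the same model object it was given. It says nothing about whether apex supports being initialized after DDP wrapping.

```python
class Module:
    """Hypothetical stand-in for the DETR nn.Module."""

    def __init__(self, name):
        self.name = name


class Wrapper:
    """Hypothetical stand-in for DistributedDataParallel: keeps the wrapped module in .module."""

    def __init__(self, module):
        self.module = module


def fake_initialize(model, optimizer):
    """Stand-in for amp.initialize. ASSUMPTION: it returns the model object it received."""
    return model, optimizer


# Same order as the original main.py: wrap first, then take .module.
model = Module("detr")
model = Wrapper(model)
model_without_ddp = model.module
wrapped = model  # extra handle on the wrapper, only for the comparison below

# Same assignment pattern as in the last snippet above.
model, optimizer = fake_initialize(model_without_ddp, optimizer="adamw")

# After the call, `model` names the inner module and no longer the wrapper.
rebinding_dropped_wrapper = (model is model_without_ddp) and (model is not wrapped)
print(rebinding_dropped_wrapper)
```

If that assumption holds for the real amp.initialize, the training loop would then call the unwrapped module through model, so the result would probably have to be assigned to model_without_ddp instead to keep the DDP wrapper in use.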