Apex with `torch.optim.swa_utils`

I am currently trying to use apex with SWA like so:

import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.optim.swa_utils import SWALR

model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
model = DDP(model)

swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
swa_start = 6
swa_scheduler = SWALR(optimizer, swa_lr=0.05 * world_size)

At this line (swa_model = torch.optim.swa_utils.AveragedModel(model)) I am getting the following error:

Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/jupyter/Flood_Comp/starter.py", line 249, in train
    swa_model = torch.optim.swa_utils.AveragedModel(model)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/optim/swa_utils.py", line 89, in __init__
    self.module = deepcopy(model)
  File "/opt/conda/lib/python3.7/copy.py", line 169, in deepcopy
    rv = reductor(4)
  File "/opt/conda/lib/python3.7/site-packages/apex/parallel/distributed.py", line 271, in __getstate__
    del attrs['self.bucket_streams']
KeyError: 'self.bucket_streams'

Any pointers on mitigating this would be helpful.
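
One possible workaround, sketched here under the assumption that apex's DistributedDataParallel exposes the wrapped network as .module (as the native wrapper does) and not verified in this thread: AveragedModel deepcopies whatever it is given, and that deepcopy trips over the apex wrapper's __getstate__. Averaging the underlying module keeps the wrapper out of the copy:

# Hedged sketch: build the averaged model from the plain network so that the
# deepcopy inside AveragedModel never touches apex's DDP wrapper.
swa_model = torch.optim.swa_utils.AveragedModel(model.module)

# Later in the training loop, update the average from the live weights:
swa_model.update_parameters(model.module)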

apex.amp is deprecated in favor of the native torch.cuda.amp implementation, and we recommend switching to the latter.
More details are given in this post.
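
For reference, a minimal sketch of the native pattern that replaces amp.initialize and amp.scale_loss (assuming a standard single-optimizer loop; model, optimizer, loss_fn, and loader are placeholders, not code from this thread):

import torch

scaler = torch.cuda.amp.GradScaler()

for images, targets in loader:
    optimizer.zero_grad(set_to_none=True)
    # autocast replaces apex's opt_level-based casting for the forward pass
    with torch.cuda.amp.autocast():
        output = model(images)
        loss = loss_fn(output, targets)
    # GradScaler replaces the apex amp.scale_loss context manager
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()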

Okay. Strangely enough, I am unable to get the same benefits from torch.cuda.amp that I got from apex.amp. I will look into the post you suggested and see if there's anything I am missing.

After following @ptrblck’s suggestions, here's what my train() function looks like (consider this to be the launcher expected by torch.multiprocessing.spawn()):

def train(rank, num_epochs, world_size):
    init_process(rank, world_size)
    torch.manual_seed(0)
    
    model = create_model()
    torch.cuda.set_device(rank)
    model.cuda(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    swa_model = torch.optim.swa_utils.AveragedModel(model)

    learning_rate = 1e-3
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate * world_size)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    swa_start = 10
    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
    criteria = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    
    train_loader, val_loader = get_dataloader(rank, world_size)
    
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            with torch.cuda.amp.autocast(enabled=True):
                image = batch['image'].cuda(rank, non_blocking=True)
                mask = batch['mask'].cuda(rank, non_blocking=True)
                
                pred = model(image)
                
                loss = criteria(pred, mask.unsqueeze(1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            
        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

Things are now working as expected. @ptrblck, if there is anything I could do better to further optimize performance, please let me know.

The majority of the SWA code comes from the official docs.
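
For context, the tail end of the recipe in the official docs finishes by recomputing the batch-norm statistics and then predicting with the averaged model. This is the generic pattern (update_bn expects a loader yielding tensors or (input, target) tuples), not the exact code used here:

# Recompute BatchNorm running statistics for the averaged weights.
torch.optim.swa_utils.update_bn(train_loader, swa_model)

# Use swa_model (not model) for validation / test predictions.
swa_model.eval()
with torch.no_grad():
    prediction = swa_model(image)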

@ptrblck I am currently running into another problem that is closely related. After training with SWA, we need to update the batch norm statistics (reference). Since the structure of my dataset is different from what torch.optim.swa_utils.update_bn() expects, I am doing the following inside train() (recall that train() is the launcher I provide to mp.spawn()):

if rank == 0:
    for batch in train_loader: 
        image = batch['image'].cuda(rank, non_blocking=True)
        prediction = swa_model(image)

This leads to the following error:

Traceback (most recent call last):
  File "starter.py", line 343, in <module>
    nprocs=WORLD_SIZE, join=True
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/jupyter/Flood_Comp/starter.py", line 334, in train
    prediction = swa_model(image)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/optim/swa_utils.py", line 101, in forward
    return self.module(*args, **kwargs)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 610, in forward
    self._sync_params()
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 1048, in _sync_params
    authoritative_rank,
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 979, in _distributed_broadcast_coalesced
    self.process_group, tensors, buffer_size, authoritative_rank
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:575] Connection closed by peer [10.138.0.33]:26791

Anything I am missing out on?

Are you seeing this error using mp.spawn and swa in isolation or only in combination with torch.cuda.amp?

Only in combination with all three. I have not tried them in isolation yet.

Could you try to isolate it further, which would help with debugging? In particular, it would be interesting to see if your custom multiprocessing approach works with amp or swa in isolation, as this often causes trouble if you are not careful.

With just amp it works fine; I have verified this before.

@ptrblck any further suggestions?

I’m not familiar with the internals of swa, so you could check the multiprocessing best practices as well as the docs about sharing CUDA tensors.
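
For completeness, one possible direction (a hedged sketch, not confirmed in this thread): the failure comes from calling the DDP-wrapped model on rank 0 only, while DDP's forward pass expects every rank to participate in its collectives. Running the batch-norm update through the underlying plain module on a single rank (or, alternatively, running the loop on every rank) avoids that, provided the running statistics are reset first, mirroring what torch.optim.swa_utils.update_bn does internally:

# Hedged sketch: swa_model wraps the DDP model, so swa_model.module is the DDP
# wrapper and swa_model.module.module is the plain network (an assumption based
# on the wrapping order used in train() above).
if rank == 0:
    plain_model = swa_model.module.module
    # Reset BN running stats and switch to cumulative averaging, as update_bn does.
    momenta = {}
    for m in plain_model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            momenta[m] = m.momentum
            m.momentum = None  # None => cumulative moving average
    plain_model.train()
    with torch.no_grad():
        for batch in train_loader:
            image = batch['image'].cuda(rank, non_blocking=True)
            plain_model(image)
    # Restore the original BN momenta.
    for m, mom in momenta.items():
        m.momentum = mom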