Does DistributedOptimizer support zero_grad and lr_scheduling?

I am trying to use distributed_rpc to implement a parameter server. I would like to use the zero_grad and lr_scheduling features in the trainer, but it seems that DistributedOptimizer does not support them. Is there any workaround?

Hey @Kunlin_Yang

zero_grad

The reason DistributedOptimizer does not provide a zero_grad API is that the gradients of each backward pass are stored in their own dedicated context (instead of in param.grad), and the context is cleared when exiting the with dist_autograd.context() as context_id: scope. So zero_grad is not needed here. We can certainly add it if necessary, and it would not be too hard to implement in the application code either. E.g.,

import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc


def zero_grad(pr, context_id):
    # look up this parameter's gradient in the given autograd context
    # and zero it in place on the parameter's owner
    dist_autograd.get_gradients(context_id)[pr.local_value()].zero_()


with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    futs = []
    for pr in param_rrefs:
        futs.append(rpc.rpc_async(pr.owner(), zero_grad, args=(pr, context_id)))
    [fut.wait() for fut in futs]
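
For context, here is a minimal sketch of the usual training loop where each iteration runs in its own distributed autograd context, so gradients never accumulate across iterations (model, loss_fn, data_loader, and dist_optim are placeholders, not defined in the snippet above):

import torch.distributed.autograd as dist_autograd

for data, target in data_loader:
    # every iteration opens a fresh autograd context; its gradients live
    # only inside this context, so there is nothing to zero out manually
    with dist_autograd.context() as context_id:
        loss = loss_fn(model(data), target)
        dist_autograd.backward(context_id, [loss])
        dist_optim.step(context_id)
    # the context (and the gradients it holds) is released here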

Or is there a different reason you would like to use the zero_grad API?

lr_scheduling

Currently, there is no distributed implementation of lr scheduling yet. I created an issue to track this: https://github.com/pytorch/pytorch/issues/38548

For now, you will need to do that using the raw RPC API. You can access the RRefs of the remote optimizers through the remote_optimizers attribute of the DistributedOptimizer instance, so it can be something like:

def create_lr_scheduler(opt_rref):
    # create and return an lr scheduler for the optimizer held by this RRef
    ...

def lrs_step(lrs_rref):
    lrs_rref.local_value().step()

opt = DistributedOptimizer(...)
lrs_rrefs = []
for opt_rref in opt.remote_optimizers:
    lrs_rrefs.append(rpc.remote(opt_rref.owner(), create_lr_scheduler, args=(opt_rref,)))

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    futs = []
    for lrs_rref in lrs_rrefs:
        futs.append(rpc.rpc_async(lrs_rref.owner(), lrs_step, args=(lrs_rref,)))
    [fut.wait() for fut in futs]

If you are using the master branch, the above code can be simplified with the RRef.rpc_async() API.
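
For example, the last with block above could then look roughly like this (a sketch, assuming the objects behind lrs_rrefs expose a step() method, as the schedulers created above do):

with dist_autograd.context() as context_id:
    # omitting forward-backward-optstep here
    # RRef.rpc_async() invokes the method on the RRef's owner and returns a
    # Future, so the lrs_step wrapper function is no longer needed
    futs = [lrs_rref.rpc_async().step() for lrs_rref in lrs_rrefs]
    [fut.wait() for fut in futs]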


I am also using distributed RPC and trying to implement an lr_scheduler. I tried your suggestion; however, I am not able to update the learning rate. It still prints the initial learning rate. It would be great if you could help.

I am subclassing DistributedOptimizer as follows:

from torch.distributed import rpc
from torch.distributed.optim import DistributedOptimizer

class MyDistributedOptimizer(DistributedOptimizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def fetch_learning_rate(self, optimizer_rref):
        # read the lr stored on the local optimizer wrapped by this RRef
        optimizer_local = optimizer_rref.local_value()
        return optimizer_local.optim.defaults['lr']

    def get_learning_rates(self):
        learning_rates = {}
        for optimizer in self.remote_optimizers:
            if optimizer.owner() == rpc.get_worker_info():
                # the optimizer lives on this worker, read it directly
                optimizer_local = optimizer.local_value()
                learning_rates_on_owner = optimizer_local.optim.defaults['lr']
                learning_rates[f'worker_{optimizer.owner().name}'] = learning_rates_on_owner
            else:
                # the optimizer lives on a remote worker, fetch the lr over RPC
                learning_rates_on_owner = rpc.rpc_sync(
                    optimizer.owner(),
                    self.fetch_learning_rate,
                    args=(optimizer,)
                )
                learning_rates[f'worker_{optimizer.owner().name}'] = learning_rates_on_owner
        return learning_rates

The optimizer is called as follows:

optimizer = MyDistributedOptimizer(
    optim.AdamW,
    model.parameter_rrefs(),
    lr=config.learning_rate,
)

I have defined the custom scheduler as follows:

import math

from torch.distributed import rpc

class CustomLRScheduler:
    def __init__(self, distributed_optimizer_rref, num_warmup_steps, num_training_steps, num_cycles=0.5):
        self.distributed_optimizer_rref = distributed_optimizer_rref
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.num_cycles = num_cycles
        self.last_epoch = 0
        self.learning_rate = 0.0

    def step(self):
        self.last_epoch += 1
        self.learning_rate = self._get_lr(self.last_epoch)
        print(f"Updating learning rates to: {self.learning_rate}")
        try:
            rpc.rpc_sync(
                self.distributed_optimizer_rref.owner(),
                CustomLRScheduler.update_learning_rates,
                args=(self.distributed_optimizer_rref, self.learning_rate)
            )
        except Exception as e:
            print(f"Failed to update learning rates: {e}")

    def get_current_lr(self):
        return f"{self.learning_rate:.1e}"

    def _get_lr(self, current_step):
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        progress = float(current_step - self.num_warmup_steps) / float(max(1, self.num_training_steps - self.num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

    @staticmethod
    def update_learning_rates(optimizer_rref, new_lr):
        optimizer_local = optimizer_rref.local_value()
        print(f"Worker updating learning rate to: {new_lr}")
        
        # Update the defaults dictionary
        optimizer_local.optim.defaults['lr'] = new_lr

def get_lr(lrs_rref):  
    return lrs_rref.rpc_sync().get_current_lr()  
  
def create_lr_scheduler(distributed_optimizer_rref, num_warmup_steps, num_training_steps, num_cycles=0.5):  
    scheduler = CustomLRScheduler(distributed_optimizer_rref, num_warmup_steps, num_training_steps, num_cycles)  
    return scheduler  
  
def lrs_step(lrs_rref):  
    lrs = lrs_rref.local_value()  
    lrs.step()

In the training loop, I am doing the following:

lrs_rrefs = [
    rpc.remote(
        opt_rref.owner(),
        create_lr_scheduler,
        args=(opt_rref, num_warmup_steps, num_training_steps)
    ) for opt_rref in optimizer.remote_optimizers
]

with dist_autograd.context() as context_id:
    # forward/backward steps .........

    # Step the learning rate scheduler
    futs = []
    for lrs_rref in lrs_rrefs:
        futs.append(rpc.rpc_async(lrs_rref.owner(), lrs_step, args=(lrs_rref,)))
    [fut.wait() for fut in futs]

    current_lrs = optimizer.get_learning_rates()
    print(f"Current learning rates: {current_lrs}")
    logs = {"loss": loss.detach().item(), "learning rates": current_lrs, "step": global_step}