Using DeepSpeed in PyTorch Lightning: RuntimeError: Function ConvolutionBackward0 returned an invalid gradient at index 1

I'm using DeepSpeedPlugin (ZeRO stage 3) to speed up training in PyTorch Lightning 1.4.2. However, the following error occurs during the backward pass: RuntimeError: Function ConvolutionBackward0 returned an invalid gradient at index 1 - got [1280, 1280, 3, 3] but expected shape compatible with [0].

This confuses me, because when I set the accelerator to ddp, the model trains successfully. The bug also appears with PyTorch Lightning 2.4.0.
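To compare the ddp run and the DeepSpeed run described above without editing the script each time, the backend can be selected per run. A minimal sketch: `backend_trainer_kwargs` is a hypothetical helper, and it assumes the string aliases `"ddp"` and `"deepspeed_stage_3"` that Lightning accepts (through `accelerator`/`plugins` in 1.4.x, through `strategy` in 2.x); the `plugins` alias stands in for `DeepSpeedPlugin(stage=3)`.

```python
from typing import Any, Dict


def backend_trainer_kwargs(backend: str, pl_major: int) -> Dict[str, Any]:
    """Return extra pl.Trainer kwargs that select a distributed backend.

    Hypothetical helper. Assumes Lightning's string aliases "ddp" and
    "deepspeed_stage_3" (accelerator/plugins in 1.x, strategy in 2.x).
    """
    if backend not in ("ddp", "deepspeed_stage_3"):
        raise ValueError(f"unsupported backend: {backend!r}")
    if pl_major >= 2:
        # Lightning 2.x selects both backends through `strategy`.
        return {"strategy": backend}
    if backend == "ddp":
        # Lightning 1.4.x: DDP is chosen through `accelerator`.
        return {"accelerator": "ddp"}
    # Lightning 1.4.x: registered alias equivalent to DeepSpeedPlugin(stage=3).
    return {"plugins": "deepspeed_stage_3"}
```

For example, `pl.Trainer(callbacks=callbacks, logger=logger, **backend_trainer_kwargs("ddp", 1), **config.lightning.trainer)`, provided `config.lightning.trainer` does not already set the same key (Python rejects duplicate keyword arguments).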

What version are you seeing the problem on?

v1.x

How to reproduce the bug

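from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DeepSpeedPlugin

# instantiate_from_config and load_state_dict are project helpers;
# their import path is not shown in this snippet.


def main() -> None:
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default='configs/train.yaml')
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    pl.seed_everything(config.lightning.seed, workers=True)

    data_module = instantiate_from_config(config.data)
    model = instantiate_from_config(OmegaConf.load(config.model.config))
    # TODO: resume states saved in checkpoint.
    if config.model.get("resume"):
        load_state_dict(model, torch.load(config.model.resume, map_location="cpu"), strict=True)

    callbacks = []
    for callback_config in config.lightning.callbacks:
        callbacks.append(instantiate_from_config(callback_config))

    logger = TensorBoardLogger(save_dir=config.lightning.trainer.default_root_dir, version=1, name="lightning_logs")
    # ZeRO stage 3 through the PL 1.4.x plugin API; the same script trains fine with accelerator="ddp".
    trainer = pl.Trainer(callbacks=callbacks, logger=logger, plugins=DeepSpeedPlugin(stage=3), **config.lightning.trainer)
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    main()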
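The last project frame in the traceback below is a custom gradient-checkpointing backward in `ldm/modules/diffusionmodules/util.py`, so one way to narrow things down is to rerun with that checkpointing turned off. A minimal sketch, assuming the model config follows the ldm convention of a `use_checkpoint` flag (an assumption; the model config is not shown here); `set_key_recursively` is a hypothetical helper that works on plain dicts and lists.

```python
from typing import Any


def set_key_recursively(node: Any, key: str, value: Any) -> int:
    """Set every occurrence of `key` in a nested dict/list config to `value`.

    Returns how many keys were changed, so a result of 0 signals that the
    config uses a different flag name.
    """
    changed = 0
    if isinstance(node, dict):
        for k, v in node.items():
            if k == key:
                # Reassigning an existing key does not resize the dict, so this is safe mid-iteration.
                node[k] = value
                changed += 1
            else:
                changed += set_key_recursively(v, key, value)
    elif isinstance(node, list):
        for item in node:
            changed += set_key_recursively(item, key, value)
    return changed
```

For example: `raw = OmegaConf.to_container(OmegaConf.load(config.model.config), resolve=True)`, then `set_key_recursively(raw, "use_checkpoint", False)`, then `model = instantiate_from_config(OmegaConf.create(raw))`. If training then passes under DeepSpeed stage 3, that would point at the interaction between the custom checkpoint and ZeRO stage 3 rather than at the convolution itself.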

Error messages and logs

Traceback (most recent call last):
  File "/data3/sxhong/project/CCSR_deepspeed/train.py", line 36, in <module>
    main()
  File "/data3/sxhong/project/CCSR_deepspeed/train.py", line 32, in main
    trainer.fit(model, datamodule=data_module)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
    self._dispatch()
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
    batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 101, in run
    super().run(batch, batch_idx, dataloader_idx)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 148, in advance
    result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 202, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 396, in _optimizer_step
    model_ref.optimizer_step(
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1618, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 209, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 129, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 292, in optimizer_step
    make_optimizer_step = self.precision_plugin.pre_optimizer_step(
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/deepspeed_precision.py", line 47, in pre_optimizer_step
    lambda_closure()  # DeepSpeed does not support closures
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 236, in _training_step_and_backward_closure
    result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 549, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 590, in backward
    result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 276, in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/deepspeed_precision.py", line 60, in backward
    deepspeed_engine.backward(closure_loss, *args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1967, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/deepspeed/runtime/zero/stage3.py", line 2213, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/data3/sxhong/project/CCSR_deepspeed/ldm/modules/diffusionmodules/util.py", line 145, in backward
    input_grads = torch.autograd.grad(
  File "/data3/sxhong/miniconda3/envs/CCSR/lib/python3.9/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function ConvolutionBackward0 returned an invalid gradient at index 1 - got [1280, 1280, 3, 3] but expected shape compatible with [0]
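The `[0]` in the message means autograd expected the weight gradient for a tensor with no elements, while the convolution produced a full `[1280, 1280, 3, 3]` gradient. That is consistent with a parameter whose local data is an empty placeholder, which is how ZeRO stage 3 keeps partitioned parameters between gathers (an assumption about the cause, not confirmed by this log). A minimal plain-PyTorch check that lists such parameters; `find_empty_params` is a hypothetical helper, and the demo only mimics partitioning by hand.

```python
from typing import List

import torch
from torch import nn


def find_empty_params(model: nn.Module) -> List[str]:
    """Return the names of parameters whose local tensor has no elements."""
    return [name for name, p in model.named_parameters() if p.numel() == 0]


if __name__ == "__main__":
    conv = nn.Conv2d(4, 4, kernel_size=3)
    print(find_empty_params(conv))  # []
    # Mimic a partitioned parameter: replace its data with an empty tensor.
    conv.weight.data = torch.empty(0)
    print(find_empty_params(conv))  # ['weight']
```

Under plain DDP the list should stay empty; calling the helper on the model at the failing step under ZeRO stage 3 would show which layers are partitioned at that point.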

Environment

- pytorch-lightning 1.4.2
- torch 2.0.1+cu118
- torchmetrics 0.6.0
- torchvision 0.15.2+cu118
- deepspeed 0.14.4
- python 3.9
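Since the error was checked on two Lightning versions, the list above can be regenerated in each environment. A small sketch using the standard-library `importlib.metadata`; `collect_versions` is a hypothetical helper and the package list simply mirrors the one above.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

PACKAGES = ("pytorch-lightning", "torch", "torchmetrics", "torchvision", "deepspeed")


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, version in collect_versions().items():
        print(f"{name:<20} {version or 'not installed'}")
```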