How to run PyTorch DDP with Hydra + Optuna?

I am trying to use the Hydra Optuna sweeper with my PyTorch DDP (`mp.spawn`) setup. The PyTorch DDP training itself runs properly. However, when I run the Hydra sweeper, it doesn't work: it stops after a single run. How can I make it run the whole sweep and actually tune the hyperparameters?

In the code below, `main` is called from `manual_proc_runner`, which `mp.spawn` launches once per GPU.

```python
import hydra
import torch.multiprocessing as mp
from hydra.core.hydra_config import HydraConfig

# ModelConfig, CustomManualEnvironment and manual_proc_setup are my own helpers (not shown).


def manual_proc_runner(rank, cfg, m, results):  # runs in each spawned process
    world_size = m.get_world_size()  # 2

    manual_proc_setup(rank=rank, world_size=world_size)

    # ----- DDP Manual Env Setup -----
    env = CustomManualEnvironment(world_size=world_size, rank=rank)
    loss = main(cfg=cfg, m=m, env=env)

    results[rank] = loss  # shared Manager list, read back by the parent process
    print(f'results:{results} in rank:{rank}')

    # clean up DDP process group after completion


@hydra.main(version_base=None, config_path='../config_files', config_name='config')
def hydra_manual_proc_launcher(cfg):  # Hydra passes only 'cfg' to the task function
    m = ModelConfig('./config_files/drivaer_domains_ddp.ini')
    world_size = m.get_world_size()  # 2

    out_dir = HydraConfig.get().runtime.output_dir

    with mp.Manager() as manager:
        results = manager.list([None] * world_size)
        mp.spawn(manual_proc_runner, args=(cfg, m, results), nprocs=world_size, join=True)
        aggregated_loss = sum(filter(None, results)) / len(results)
        print(f'Aggregated loss: {aggregated_loss}')

    return aggregated_loss  # value the Optuna sweeper should minimize


def main(cfg, m, env):
    ...  # training code omitted here


if __name__ == '__main__':  # required: mp.spawn re-imports this module in every child
    hydra_manual_proc_launcher()
```

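The aggregation line `sum(filter(None, results)) / len(results)` leaves ranks that stored `None` out of the sum but still counts them in the denominator, so a rank that never writes its loss pulls the average toward zero. (Filtering out a loss of exactly `0.0` is harmless, because it adds nothing to the sum.) A minimal sketch of a stricter helper, assuming each slot of `results` holds either a number or `None`; the name `mean_of_reported_losses` is hypothetical:

```python
from typing import Optional, Sequence


def mean_of_reported_losses(results: Sequence[Optional[float]]) -> float:
    """Average only the ranks that actually wrote a loss into `results` (hypothetical helper)."""
    # float() also accepts NumPy scalars and one-element tensors.
    reported = [float(r) for r in results if r is not None]
    if not reported:
        # Fail loudly instead of handing the sweeper a meaningless objective value.
        raise RuntimeError("no rank reported a loss")
    return sum(reported) / len(reported)


if __name__ == "__main__":
    print(mean_of_reported_losses([0.25, 0.75]))  # both ranks reported
    print(mean_of_reported_losses([None, 0.5]))   # rank 0 wrote nothing
```

In the launcher it would stand in for the `aggregated_loss = ...` line as `aggregated_loss = mean_of_reported_losses(results)`, so a missing rank surfaces as an error instead of a silently biased objective.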
This is what I have, along with the following config file. When I run it, the default hyperparameters are used for one run and then the program stops instead of running the whole optimization loop.

```yaml
defaults:
  - override hydra/sweeper: optuna
  - override hydra/sweeper/sampler: tpe

storage: 'sqlite:///../output/train/checkpoints/subdomains_211_v2/optuna_dashboard.db'

hydra:
  sweeper:
    sampler:
      seed: 123
    direction: minimize
    study_name: hydra-1
    # storage: null
    n_trials: 4 #50
    n_jobs: 1
    params:
      mp_iterations: range(1, 15)
      mlp_dim: range(64,256)
      mlp_layers: range(1,4)

mp_iterations: 15
mlp_dim: 64
mlp_layers: 2

# if true, simulate a failure by raising an exception
error: false
```
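For reference, the sweeper settings above describe this control flow: the decorated task function is called once per trial (`n_trials` times), each call receives sampled values for `mp_iterations`, `mlp_dim` and `mlp_layers`, and the sweeper minimizes the float the call returns. Hydra only hands control to a sweeper in multirun mode, i.e. when the script is launched with `--multirun` (`-m`); a plain launch performs a single run with the default values. The sketch below imitates that loop with the standard library only. It uses plain random sampling in place of TPE, and `fake_trial` and `random_sweep` are hypothetical stand-ins (for `hydra_manual_proc_launcher` and the sweeper), so it illustrates the flow rather than the plugin's internals.

```python
import random
from typing import Callable, Dict, List, Tuple

# Search space mirroring hydra.sweeper.params. Python range() semantics are used
# here only for illustration; the real plugin's bound handling is not modeled.
SEARCH_SPACE: Dict[str, range] = {
    "mp_iterations": range(1, 15),
    "mlp_dim": range(64, 256),
    "mlp_layers": range(1, 4),
}


def fake_trial(params: Dict[str, int]) -> float:
    """Hypothetical stand-in for the Hydra task function: returns a float loss."""
    return (
        abs(params["mp_iterations"] - 8)
        + abs(params["mlp_dim"] - 128) / 64
        + params["mlp_layers"]
    )


def random_sweep(
    objective: Callable[[Dict[str, int]], float],
    n_trials: int = 4,
    seed: int = 123,
) -> Tuple[Dict[str, int], float, List[float]]:
    """Call the objective once per trial and keep the lowest value (direction: minimize)."""
    rng = random.Random(seed)
    history: List[float] = []
    best_params: Dict[str, int] = {}
    best_value = float("inf")
    for _ in range(n_trials):
        # Sample one value per hyperparameter, like the sweeper's overrides.
        params = {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}
        value = objective(params)  # in the real setup: one full mp.spawn DDP run
        history.append(value)
        if value < best_value:
            best_params, best_value = params, value
    return best_params, best_value, history


if __name__ == "__main__":
    params, value, history = random_sweep(fake_trial)
    print(f"{len(history)} trials, best loss {value:.3f} with {params}")
```

With `n_trials=1` the loop degenerates to a single call, which matches the symptom of one run with the defaults and then a stop.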