Same model and data somehow OOM on a better device

I’m training a TTS model and training runs fine on an RTX 4090, but after moving to a machine with 8× A100-SXM4-40GB, it suddenly hits a CUDA OOM error after thousands of epochs.

Here is the situation in detail:
Training on the RTX 4090 and on the A100s uses exactly the same model, config, data, code, and CUDA version (12.2).

  1. Epochs 1–4000 are trained on the RTX 4090. GPU memory usage is shown in the screenshot below; everything is fine: no OOM, no memory leak, no growing memory usage.
    [Screenshot: RTX 4090 GPU memory usage]

  2. Everything is moved to the 8× A100-SXM4-40GB machine.

  3. Epochs 4001–7000 are still fine on the A100 machine.

  4. Around epochs 7001–7051, CUDA OOM errors start to appear.
    [Screenshot: GPU memory usage at epoch 7005]

    [Screenshot: GPU memory usage near epoch 7050]

  5. When loading data for epoch 7051, a CUDA OOM occurs.
    [Screenshot: CUDA OOM error]

If there were a memory leak or a coding problem, I suppose it should also show up on the 4090, or on the A100s during epochs 4000–7000, but it only happens after so much stable training progress, so I have no clue how this is happening.

Any advice or clue is appreciated, thank you!

Here is the source code I used for training.
https://github.com/p0p4k/vits2_pytorch

The memory usage doesn’t look stable, as in the first screenshot you are already seeing a heavy imbalance. E.g. could you explain why device 0 uses approx. 9 GB of its memory while device 3 uses approx. 18 GB? Is each device getting different input shapes, which could increase the memory usage?
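A small helper like this (a rough sketch; the exact batch structure depends on your collate_fn, so it just iterates over whatever tensors the loader returns) could be called from the training loop to compare the input shapes and memory stats across ranks:

import torch

def log_batch_stats(rank, step, batch):
    # Shapes of all tensors in the batch plus this rank's allocator stats.
    shapes = [tuple(t.shape) for t in batch if torch.is_tensor(t)]
    alloc_mb = torch.cuda.memory_allocated(rank) / 1024**2
    reserved_mb = torch.cuda.memory_reserved(rank) / 1024**2
    print(f"rank {rank} | step {step} | shapes {shapes} | "
          f"allocated {alloc_mb:.0f} MB | reserved {reserved_mb:.0f} MB")

Calling it every few hundred steps, e.g. log_batch_stats(rank, batch_idx, batch), would show whether one rank consistently receives much longer sequences (and therefore a much higher allocation) than the others.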

Honestly, I have no idea. It’s my first time training with multiple GPUs; the memory usage just keeps changing, and each GPU gets a different memory usage when loading data, so I thought it was normal…
[Screenshot: per-GPU memory usage over time]

The memory usage won’t decrease unless you are clearing the cache. So could you explain why you are calling into torch.cuda.empty_cache()?
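nvidia-smi and similar dashboards report the memory reserved by PyTorch’s caching allocator (plus the CUDA context), not the memory currently held by live tensors, so the reported usage should not drop unless the cache is emptied. A minimal demo of the difference, assuming a CUDA device is available:

import torch

x = torch.randn(4096, 4096, device="cuda")  # ~64 MB of float32
print(torch.cuda.memory_allocated() // 1024**2, torch.cuda.memory_reserved() // 1024**2)

del x  # the allocation is freed, but PyTorch keeps the block in its cache
print(torch.cuda.memory_allocated() // 1024**2, torch.cuda.memory_reserved() // 1024**2)

torch.cuda.empty_cache()  # only now is the memory returned to the driver
print(torch.cuda.memory_allocated() // 1024**2, torch.cuda.memory_reserved() // 1024**2)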

Just for the record, I’m not the author of that GitHub repository, and I’m not very experienced with building models from scratch or with the torch code.

Here is the main function:

def main():
    """Assume Single Node Multi GPUs Training Only"""
    assert torch.cuda.is_available(), "CPU training is not allowed."

    n_gpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "6060"

    hps = utils.get_hparams()
    mp.spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )

And here is the part of the run function that initializes the data loaders:

def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )

    collate_fn = TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
    )
    if rank == 0:
        eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(
            eval_dataset,
            num_workers=8,
            shuffle=False,
            batch_size=1,
            pin_memory=True,
            drop_last=False,
            collate_fn=collate_fn,
        )

torch.cuda.empty_cache() is not called in the source code. If decreasing memory is not normal during training, what else could possibly cause it?
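If it helps with debugging, I can wrap the training step in something like this so the batch shapes and a memory summary get printed the next time it crashes (rough sketch, untested; train_one_step is just a placeholder for the existing forward/backward/optimizer code, and torch.cuda.OutOfMemoryError needs a fairly recent PyTorch, otherwise catching RuntimeError would do):

import torch

def guarded_step(rank, step, batch, train_one_step):
    # train_one_step stands in for the existing forward/backward/optimizer code.
    try:
        return train_one_step(batch)
    except torch.cuda.OutOfMemoryError:
        shapes = [tuple(t.shape) for t in batch if torch.is_tensor(t)]
        print(f"OOM on rank {rank} at step {step}, batch tensor shapes: {shapes}")
        print(torch.cuda.memory_summary(device=rank, abbreviated=True))
        raise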