Just for the record, I'm not the author of that GitHub repository, and I'm not very experienced with building models from scratch or with PyTorch code.
Here is the main function:
def main():
    """Assume Single Node Multi GPUs Training Only"""
    assert torch.cuda.is_available(), "CPU training is not allowed."
    n_gpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "6060"
    hps = utils.get_hparams()
    mp.spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )
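As I understand it, mp.spawn passes the process index as the first argument, so each worker ends up being called as run(rank, n_gpus, hps). Below is a stripped-down sketch of that same pattern; the _worker function and its body are only my illustration, not code from the repository:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # mp.spawn prepends the process index, so rank arrives as the first argument.
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=world_size, rank=rank
    )
    torch.cuda.set_device(rank)
    # ... model setup and the training loop would go here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "6060"
    n_gpus = torch.cuda.device_count()
    mp.spawn(_worker, nprocs=n_gpus, args=(n_gpus,))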
And here is the part of the run function that initializes the data loaders:
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)
    train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
    )
    if rank == 0:
        eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(
            eval_dataset,
            num_workers=8,
            shuffle=False,
            batch_size=1,
            pin_memory=True,
            drop_last=False,
            collate_fn=collate_fn,
        )
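One thing I noticed about this setup that seems relevant to the memory question: DistributedBucketSampler groups utterances into length buckets, and TextAudioCollate pads each batch only up to the longest item in that batch (the usual behaviour for this kind of collate function), so the padded tensor shapes, and with them the per-step activation memory on the GPU, can legitimately vary from batch to batch. Here is a small debugging fragment I could drop at the end of run to confirm that; rank and train_loader come from the code above, and nothing in this fragment is part of the original repository:

    # Hypothetical debugging snippet (not in the original repo): inspect a few
    # batches to see how the padded tensor shapes change from step to step.
    for step, batch in enumerate(train_loader):
        shapes = [tuple(x.shape) for x in batch if torch.is_tensor(x)]
        print(f"rank {rank} step {step}: batch tensor shapes = {shapes}")
        if step == 4:
            break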
torch.cuda.empty_cache() is not called anywhere in the source code. If decreasing GPU memory usage is not normal during training, what else could possibly cause it?
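For context, my understanding is that torch.cuda.empty_cache() only returns unused blocks from PyTorch's caching allocator to the driver; it does not free tensors that are still referenced. Here is a small helper I could call once per training step (for example right after the optimizer step) to see whether a drop visible in nvidia-smi comes from tensors actually being freed or just from the allocator; it is a sketch of mine, not part of the repository:

def log_cuda_memory(rank, step, tag=""):
    # Hypothetical helper, not part of the repository: logs the CUDA caching
    # allocator state so a drop in nvidia-smi can be attributed either to
    # tensors actually being freed (memory_allocated goes down) or to the
    # allocator releasing cached blocks (memory_reserved goes down).
    allocated = torch.cuda.memory_allocated(rank) / 2**20
    reserved = torch.cuda.memory_reserved(rank) / 2**20
    peak = torch.cuda.max_memory_allocated(rank) / 2**20
    print(
        f"[rank {rank}] step {step} {tag}: "
        f"allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB, "
        f"peak={peak:.1f} MiB"
    )

Logging this on each rank would at least tell me whether the allocated memory itself is shrinking or only the reserved pool is changing.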