DDP loading kernels every epoch

I am running into an issue with multi-process DDP where my extra CUDA kernels are loaded on every epoch. Normally, this only happens once, at the beginning of the job. Since switching to DDP, it happens before every epoch.

This seems really wasteful. Am I doing something wrong such that it keeps reloading them, or is this expected behavior?

The following happens before every epoch:

Using /tmp/torch_extensions as PyTorch extensions root…
Detected CUDA files, patching ldflags
Emitting ninja build file /tmp/torch_extensions/fused/build.ninja…
Building extension module fused…
ninja: no work to do.
Loading extension module fused…
Using /tmp/torch_extensions as PyTorch extensions root…
Detected CUDA files, patching ldflags
Emitting ninja build file /tmp/torch_extensions/upfirdn2d/build.ninja…
Building extension module upfirdn2d…
ninja: no work to do.
Loading extension module upfirdn2d…
Using /tmp/torch_extensions as PyTorch extensions root…
Detected CUDA files, patching ldflags
Emitting ninja build file /tmp/torch_extensions/fused/build.ninja…
Using /tmp/torch_extensions as PyTorch extensions root…
Using /tmp/torch_extensions as PyTorch extensions root…
Building extension module fused…
ninja: no work to do.
Loading extension module fused…
Loading extension module fused…
Using /tmp/torch_extensions as PyTorch extensions root…
Using /tmp/torch_extensions as PyTorch extensions root…
Detected CUDA files, patching ldflags
Emitting ninja build file /tmp/torch_extensions/upfirdn2d/build.ninja…
Building extension module upfirdn2d…
ninja: no work to do.
Loading extension module upfirdn2d…
Loading extension module fused…
Using /tmp/torch_extensions as PyTorch extensions root…
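
For context, these lines are printed by torch.utils.cpp_extension.load, which JIT-compiles the kernels with ninja and caches the build under /tmp/torch_extensions; "ninja: no work to do." means the cached build is being reused and only the module import is repeated. A minimal sketch of that API (file names here are illustrative placeholders, not my actual loading code):

# Sketch of the torch.utils.cpp_extension.load call that produces the
# log lines above; the source file names are placeholder examples.
from torch.utils.cpp_extension import load

fused = load(
    name='fused',
    sources=['fused_bias_act.cpp', 'fused_bias_act_kernel.cu'],
)
# The first call compiles and caches the module under
# /tmp/torch_extensions/fused/; later calls in a fresh process re-run
# ninja (a no-op when cached) and re-import the built .so.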

Below is a stripped-down version of my code.

import argparse
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Averager, get_loaders_and_shape, and get_model are project helpers
# omitted from this stripped-down snippet.


def train(epoch, step, model, optimizer, scheduler, loader, args, gpu):
    model.train()

    averages = {'total_loss': Averager()}

    starting_step = step
    t = time.time()
    optimizer.zero_grad()
    for batch_idx, images in enumerate(loader):
        step += len(images)
        images = images.cuda(gpu)
        outputs = model(images)
        kl_zs, ll_losses, latents, generations = outputs[:4]
        prior_variances, posterior_variances = outputs[4:6]

        avg_kl_loss = torch.stack(kl_zs).mean()
        avg_ll_loss = torch.stack(ll_losses).mean()
        avg_kl_loss_penalized = avg_kl_loss * args.kl_lambda
        if args.kl_anneal:
            # linearly anneal the KL weight from 0 to 1 over the first
            # kl_anneal_end steps
            anneal_scale = max(0, min(step / args.kl_anneal_end, 1))
            avg_kl_loss_penalized *= anneal_scale
        total_loss = avg_ll_loss + avg_kl_loss_penalized

        averages['total_loss'].add(total_loss.item())

        total_loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        optimizer.zero_grad()

        # cap the number of images processed in this epoch
        if step - starting_step >= args.max_epoch_steps:
            break

    return averages['total_loss'].item(), step

def main(gpu, args):
    # one process per GPU, launched by mp.spawn below; gpu doubles as the rank
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=gpu, world_size=args.num_gpus)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_loader, test_loader, image_shape = get_loaders_and_shape(args, rank=gpu)

    model, optimizer = get_model(args, image_shape, gpu)
    scheduler = None  # the real scheduler setup is stripped out here
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    step = 0
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    total_epochs = args.num_epochs
    results = {'train_loss': []}

    for epoch in range(start_epoch, start_epoch + total_epochs):
        train_loss, step = train(epoch,
                                 step,
                                 model,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 args,
                                 gpu)
        results['train_loss'].append(train_loss)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Independent variational objects.')

    parser.add_argument('--dataset',
                        default='gymnasticsRgb',
                        type=str)
    parser.add_argument('--num_workers',
                        default=4,
                        type=int,
                        help='number of data workers')
    parser.add_argument('--num_gpus',
                        default=1,
                        type=int,
                        help='number of gpus.')
    parser.add_argument('--debug',
                        action='store_true',
                        help='use debug mode (without saving to a directory)')
    parser.add_argument('--lr',
                        default=3e-4,
                        type=float,
                        help='learning rate assuming adam.')
    parser.add_argument('--weight_decay',
                        default=0,
                        type=float,
                        help='weight decay')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--max_epoch_steps', default=200000, type=int)
    parser.add_argument('--max_test_steps', default=50000, type=int)
    parser.add_argument('--num_epochs', default=250, type=int,
                        help='the number of epochs to train for. at 200000 ' \
                        'max_epoch steps, this would go for 2500 epochs to ' \
                        'reach 5e8 steps.')
    parser.add_argument(
        '--batch_size',
        default=100,
        type=int)
    parser.add_argument('--optimizer',
                        default='adam',
                        type=str,
                        help='adam or sgd.')
    parser.add_argument('--num_transition_layers', type=int, default=4)
    parser.add_argument('--num_latents', type=int, default=2)
    parser.add_argument('--latent_dim', type=int, default=32)
    parser.add_argument('--translation_layer_dim', type=int, default=128)
    parser.add_argument(
        '--output_variance',
        type=float,
        default=.25)

    args = parser.parse_args()
    mp.spawn(main, nprocs=args.num_gpus, args=(args,))

Could you share the code snippet where you actually load the extensions? I’m assuming you’re using load_inline to load your extra kernels? If so, is that call being made on every epoch?
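
For reference, by load_inline I mean torch.utils.cpp_extension.load_inline, which compiles C++/CUDA source passed in as strings. A minimal example with a throwaway kernel, just to show the API (none of this is the poster's code):

# Minimal torch.utils.cpp_extension.load_inline example; add_one is a
# placeholder function, not the poster's kernels.
import torch
from torch.utils.cpp_extension import load_inline

cpp_source = """
#include <torch/extension.h>
torch::Tensor add_one(torch::Tensor x) { return x + 1; }
"""

module = load_inline(
    name='add_one_ext',
    cpp_sources=cpp_source,
    functions=['add_one'],  # auto-generate Python bindings for add_one
)
print(module.add_one(torch.zeros(3)))

If a call like this (or torch.utils.cpp_extension.load) executes once per epoch rather than once per process, for example inside a function called from the training loop, or in a module that freshly spawned DataLoader workers re-import, you would see these messages repeat.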