VRAM maxed out (even when changing the batch size)?

I’m playing around with ProGAN and getting some weird behavior related to VRAM usage. Especially at larger image sizes, it always seems to use very close to the maximum amount of memory on my GPU, even when I lower the batch size. This makes training quite unstable, as I sometimes get OOM errors.

I believe the problem might be in my train function and that I might be doing something incorrectly there. Perhaps someone could take a look and see if anything looks off. Specifically, I’m a bit unsure how to use the scaler when there are several optimizers:

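def train_fn(critic, gen, loader, dataset, step, alpha):  # only the arguments used below; the rest of the signature is not shown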
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]

        for _ in range(config.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)

            with torch.cuda.amp.autocast():
                fake = gen(noise, alpha, step)
                critic_real = critic(real, alpha, step).reshape(-1)
                critic_fake = critic(fake, alpha, step).reshape(-1)
                gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
                loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake))
                    + config.LAMBDA_GP * gp
                )

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            gen_fake = critic(fake, alpha, step).reshape(-1)
            loss_gen = -torch.mean(gen_fake)


        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)  # - step
        )
        alpha = min(alpha, 1)

This behavior is, however, the same in FP32 and FP16, so I don’t know what could be wrong.
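On the scaler question, one option that works with alternating critic and generator updates is to give each optimizer its own GradScaler and call step() and update() right after each backward pass. The sketch below is only an illustration under stated assumptions: the tiny linear models, the latent size of 8, and the train_step helper are hypothetical stand-ins for the ProGAN networks, and the gradient penalty is left out to keep it short. The PyTorch AMP notes also describe sharing a single scaler across optimizers, in which case update() is called once per iteration after all optimizers have stepped.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # autocast and the scalers are disabled on CPU

# Tiny stand-in models (hypothetical; not the ProGAN networks from the post).
gen = nn.Linear(8, 16).to(device)
critic = nn.Linear(16, 1).to(device)
opt_gen = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

# One scaler per optimizer, so each can be stepped and updated independently.
scaler_gen = torch.cuda.amp.GradScaler(enabled=use_amp)
scaler_critic = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(real, critic_iterations=2):
    """Run several critic updates followed by one generator update."""
    for _ in range(critic_iterations):
        noise = torch.randn(real.shape[0], 8, device=device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake = gen(noise)
            # WGAN critic loss without the gradient penalty term.
            loss_critic = critic(fake.detach()).mean() - critic(real).mean()
        opt_critic.zero_grad(set_to_none=True)
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

    with torch.cuda.amp.autocast(enabled=use_amp):
        loss_gen = -critic(gen(noise)).mean()
    opt_gen.zero_grad(set_to_none=True)
    scaler_gen.scale(loss_gen).backward()
    scaler_gen.step(opt_gen)
    scaler_gen.update()
    return loss_critic.item(), loss_gen.item()


loss_c, loss_g = train_step(torch.randn(4, 16, device=device))
print(f"critic loss {loss_c:.4f}, generator loss {loss_g:.4f}")
```

Calling update() after every critic step matters here because a GradScaler refuses a second step() on the same optimizer until update() has been called.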

Are you getting the OOMs in the same iterations with the large and small batch sizes?
If so, are you changing the input shapes or are you using static shapes?
If you lower the batch size e.g. to 1, are you still seeing OOMs?
Did you check the GPU memory usage via nvidia-smi or with PyTorch methods (e.g. torch.cuda.memory_summary())?
Note that the former would return all allocated memory, i.e. the CUDA context, the used and allocated memory as well as the cache in PyTorch, and the memory usage of all other applications.
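As a rough illustration of the PyTorch-side check, the sketch below reads the allocator counters directly. The report_cuda_memory helper is a hypothetical name introduced here; the torch.cuda functions it calls are the standard ones. "allocated" is memory held by live tensors, while "reserved" also includes the cache that PyTorch keeps for reuse, which is part of why nvidia-smi can show a near-full GPU even at a small batch size.

```python
import torch


def report_cuda_memory(device=None):
    """Return PyTorch's view of GPU memory in MiB (empty dict without CUDA)."""
    if not torch.cuda.is_available():
        return {}
    mib = 1024 ** 2
    return {
        "allocated": torch.cuda.memory_allocated(device) / mib,
        "reserved": torch.cuda.memory_reserved(device) / mib,
        "max_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }


stats = report_cuda_memory()
for name, value in stats.items():
    print(f"{name}: {value:.1f} MiB")
if stats:
    # Detailed per-pool breakdown from the caching allocator.
    print(torch.cuda.memory_summary(abbreviated=True))
```

Calling this once per epoch at different batch sizes would show whether the tensors themselves grow or only the cache does.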


I don’t know how, but I don’t seem to be getting this problem anymore. Normally, if I change the batch size, I can see the allocated memory decrease, but in this situation it was always maxed out. I have encountered this behavior in TensorFlow but never before in PyTorch. Occasionally this would cause a random OOM in the middle of an epoch (I was using static shapes).

Interestingly, I also upgraded my PyTorch version, which halved the VRAM usage and sped up each epoch by ~33%, which is pretty insane…

Edit: The performance improvements seem to be specific to the 30-series GPUs (I was just slow to notice this):