I’m playing around with ProGAN and seeing some strange behavior with VRAM usage: at larger image sizes it always takes very close to the maximum amount of memory on my GPU, even when I lower the batch size. This makes training quite unstable, and I sometimes get OOM errors.
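To make the “close to the maximum” claim more concrete, I could log what is actually allocated versus what PyTorch’s caching allocator has merely reserved between batches, with something like this (plain torch.cuda calls; the helper name is just for illustration):

import torch

def log_cuda_memory(tag=""):
    # Memory currently held by live tensors
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    # Memory reserved by the caching allocator (roughly what external tools report)
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    # High-water mark of allocated memory since the start (or last reset)
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f"{tag}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB, peak={peak:.0f} MiB")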
I suspect the problem is in my train function and that I’m doing something incorrectly there. Could someone take a look and tell me if anything looks off? Specifically, I’m unsure how to use the GradScaler correctly when there are several optimizers (see the sketch after the function for the pattern I was trying to follow).
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] with gradient penalty (WGAN-GP)
        for _ in range(config.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
            with torch.cuda.amp.autocast():
                fake = gen(noise, alpha, step)
                critic_real = critic(real, alpha, step).reshape(-1)
                critic_fake = critic(fake, alpha, step).reshape(-1)
                gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
                loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake))
                    + config.LAMBDA_GP * gp
                )

            opt_critic.zero_grad()
            scaler_critic.scale(loss_critic).backward(retain_graph=True)
            scaler_critic.step(opt_critic)
            scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            gen_fake = critic(fake, alpha, step).reshape(-1)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha (fade-in factor for the current resolution) and keep it <= 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)  # - step
        )
        alpha = min(alpha, 1)
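For context, as far as I understand the AMP docs’ multiple-optimizer example uses a single shared scaler rather than one scaler per optimizer, with update() called only once per iteration after every optimizer has stepped. A stripped-down sketch of that single-scaler version I was comparing against (compute_critic_loss / compute_gen_loss and the optimizer names are placeholders, not my real code):

# Single shared GradScaler for both optimizers (my reading of the AMP docs example)
scaler = torch.cuda.amp.GradScaler()

for real, _ in loader:
    with torch.cuda.amp.autocast():
        loss_critic = compute_critic_loss(real)  # placeholder
    opt_critic.zero_grad()
    scaler.scale(loss_critic).backward()
    scaler.step(opt_critic)

    with torch.cuda.amp.autocast():
        loss_gen = compute_gen_loss()  # placeholder
    opt_gen.zero_grad()
    scaler.scale(loss_gen).backward()
    scaler.step(opt_gen)

    # update() only once per iteration, after all optimizers have stepped
    scaler.update()

Whether the two-scaler version in my train_fn above is equivalent to this is exactly the part I’m unsure about.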
This behavior is consistent in both FP32 and FP16, however, so I don’t know what could be wrong.
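In case it matters, the simplest way I know to switch the same loop between FP16 and FP32 is via the enabled flags, roughly like this (USE_AMP is a hypothetical config switch, not something that exists in my code):

# Hypothetical switch: with enabled=False, autocast and GradScaler become no-ops,
# so the unchanged train_fn effectively runs in plain FP32.
USE_AMP = True

scaler_gen = torch.cuda.amp.GradScaler(enabled=USE_AMP)
scaler_critic = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# and inside train_fn the contexts would become:
with torch.cuda.amp.autocast(enabled=USE_AMP):
    ...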