Any idea why I am getting AssertionError: No inf checks were recorded for this optimizer.
Here is the code snippet:
def generator_step(args, device, planes, netG, optG, scalerG, discriminators, pos_coeff, active_plane, x_3d_prev, t):
gen_loss = 0.
netG.zero_grad()
x_3d = denoising_step(args, device, netG, pos_coeff, x_3d_prev, t)
for plane in planes:
if args.discriminator_cycle_length > 0 and plane != active_plane: # type: ignore
continue
cur = extract_plane_slices(x_3d, plane, args.slices_per_axis)
prev = extract_plane_slices(x_3d_prev, plane, args.slices_per_axis)
t_exp = t.unsqueeze(1).repeat(1, args.slices_per_axis).view(-1)
out_fake_G = discriminators[plane](cur, t_exp, prev).view(-1)
gen_loss += F.softplus(-out_fake_G).mean() / len(planes) if args.discriminator_cycle_length <= 0 else F.softplus(-out_fake_G).mean()
scalerG.scale(gen_loss).backward()
scalerG.step(optG)
scalerG.update()
if args.ema:
optG.swap_parameters_with_ema(store_params_in_ema=True)
return x_3d,gen_loss
def denoising_step(args, device, netG, pos_coeff, x_3d_prev, t):
z = torch.randn(args.batch_size, args.nz, device=device)
with autocast(dtype=torch.float16):
x0 = netG(x_3d_prev, t, z)
x_3d = sample_posterior(pos_coeff, x0, x_3d_prev, t).detach()
return x_3d