I’ve been trying to train a pix2pix model in PyTorch as one of my first projects with this framework - I’m a new user, coming from tf/keras.
I’m using the generator/discriminator network definitions from this repository and similar training code.
I’m running into a problem where some of my generator network’s weights become inf during the very first training iteration, immediately when I call the optimizer’s step() function.
If I set torch.autograd.set_detect_anomaly(True) I get the error message: RuntimeError: Function 'CudnnConvolutionBackward' returned nan values in its 0th output.
I’m training on images loaded as half-precision floats and scaled to [0, 1] by dividing pixel values by 255.
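For reference, prepare_batch (used in the training loop below) does essentially the following - this is a minimal sketch of the preprocessing just described, not my exact loader code:

# Sketch of the preprocessing: uint8 images -> fp16 tensors in [0, 1] on the GPU
def prepare_batch(batch):
    inp, tar = batch  # uint8 image tensors of shape (N, C, H, W)
    inp = inp.half().div(255).cuda()
    tar = tar.half().div(255).cuda()
    return inp, tar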
Before training, I ran a test input through the generator and discriminator to check that everything behaved as expected with randomly initialized weights. Both networks produced sane outputs - no nan or inf - so the problem only appears after the first call of the generator optimizer’s step().
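That pre-training check was roughly the following (using the same G, D and data iterator as in the training code below):

# Rough version of the pre-training sanity check: push one batch through
# both freshly initialized networks and assert all outputs are finite
with torch.no_grad():
    inp, tar = prepare_batch(next(iter(ds_iterator)))
    generated = G(inp)
    d_pred = D(inp, generated)
    assert torch.isfinite(generated).all()
    assert torch.isfinite(d_pred).all()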
My problem is similar to what’s described in this thread, except I didn’t find answers there.
Here is my training code:
import warnings
import torch
import models  # external import: pix2pix generator/discriminator definitions

# G: U-net generator model, D: PatchGAN discriminator
G = models.GeneratorUNet()
D = models.Discriminator()
models.init_weights(G)  # initialize weights with normal(0.0, 0.02)
models.init_weights(D)
adv_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()
pixelwise_loss_weight = 100  # weight on the L1 term, as in the pix2pix paper

assert torch.cuda.is_available()
G = G.half().cuda()
D = D.half().cuda()
adv_loss.cuda()
pixelwise_loss.cuda()
device = torch.device("cuda:0")
Tensor = torch.cuda.HalfTensor
torch.autograd.set_detect_anomaly(True)

# lr, b1, b2 (learning rate and Adam betas) are defined earlier
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))
def check_state(model):
    for key, val in model.state_dict().items():
        check_numeric(key, val)

def check_numeric(name, val):
    if val.isnan().any():
        raise ValueError(f'Found nan values in network element {name}:\n{val}')
    if val.isinf().any():
        raise ValueError(f'Found inf values in network element {name}:\n{val}')

def check_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            if not torch.isfinite(param.grad).all():
                raise ValueError(f'{name} has a non-finite gradient')
        else:
            warnings.warn(f'{name}.grad is None')
n_epoch = 5
for epoch_i in range(n_epoch):
    for batch_i, batch in enumerate(ds_iterator):
        inp, tar = prepare_batch(batch)
        check_numeric('batch-input', inp); check_numeric('batch-target', tar)

        # PatchGAN discriminator labels plus some noise
        # (plain tensors; requires_grad defaults to False and Variable is deprecated)
        label_tar = torch.ones(d_output_shape, device=device, dtype=torch.float16) \
            + 0.05 * torch.rand(d_output_shape, device=device, dtype=torch.float16)
        label_generated = torch.zeros(d_output_shape, device=device, dtype=torch.float16) \
            + 0.05 * torch.rand(d_output_shape, device=device, dtype=torch.float16)
        # Generator train step
        optimizer_G.zero_grad()
        generated = G(inp)
        d_fake_prediction = D(inp, generated)
        check_numeric('discriminator prediction', d_fake_prediction)
        print(d_fake_prediction)
        g_adv_loss = adv_loss(d_fake_prediction, label_tar)
        check_numeric('Generator adversarial loss', g_adv_loss)
        g_pixelwise_loss = pixelwise_loss(generated, tar)
        check_numeric('Generator pixelwise loss', g_pixelwise_loss)
        g_tot_loss = g_adv_loss + pixelwise_loss_weight * g_pixelwise_loss
        check_numeric('Generator loss', g_tot_loss)
        g_tot_loss.backward()

        # ==== Update generator - the problematic step
        check_gradients(G)
        check_state(G)
        optimizer_G.step()
        check_state(G)
        # Discriminator train step
        optimizer_D.zero_grad()
        d_inp_prediction = D(inp, tar)
        d_real_loss = adv_loss(d_inp_prediction, label_tar)
        d_generated_prediction = D(inp, generated.detach())
        d_generated_loss = adv_loss(d_generated_prediction, label_generated)
        d_tot_loss = d_generated_loss + d_real_loss
        d_tot_loss.backward()
        check_gradients(D)
        check_state(D)
        optimizer_D.step()
        check_state(D)
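One extra diagnostic I’m considering, in case the problem sits inside the optimizer rather than in the gradients: checking Adam’s per-parameter moment buffers right after the failing step(). This is just a sketch, assuming the standard exp_avg/exp_avg_sq state keys that torch.optim.Adam stores per parameter:

# Sketch (not yet run): look for non-finite values in Adam's internal
# first/second moment buffers ('exp_avg' / 'exp_avg_sq') after step()
for name, param in G.named_parameters():
    state = optimizer_G.state.get(param, {})
    for key in ('exp_avg', 'exp_avg_sq'):
        if key in state and not torch.isfinite(state[key]).all():
            print(f'non-finite {key} in {name}')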
The check_state(G) call after optimizer_G.step() finds inf values in the generator’s down1.model.0 weights, which is the very first Conv2d layer in the U-net:
--------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-13-08a1061796d8> in <module>()
45 check_state(G)
46 optimizer_G.step()
---> 47 check_state(G)
48
49
<ipython-input-12-53bce0ba077d> in check_state(model)
1 def check_state(model):
2 for key, val in model.state_dict().items():
----> 3 check_numeric(key, val)
4
5
<ipython-input-12-53bce0ba077d> in check_numeric(name, val)
8 raise ValueError(f'Found nan values in network element {name}:\n{val}')
9 if val.isinf().any():
---> 10 raise ValueError(f'Found inf values in network element {name}:\n{val}')
11
12 def check_gradients(model):
ValueError: Found inf values in network element down1.model.0.weight:
tensor([[[[-0.0071, 0.0104, -0.0214, -0.0032],
[ 0.0017, -0.0107, 0.0184, -0.0110],
[ 0.0123, -0.0030, 0.0021, 0.0219],
[ 0.0100, -0.0283, 0.0221, -0.0306]]],
[[[ 0.0104, -0.0177, -0.0127, 0.0025],
[ 0.0034, 0.0259, -0.0134, -0.0453],
[-0.0383, -0.0414, -0.0170, -0.0095],
[-0.0069, 0.0054, 0.0053, -0.0057]]],
[[[-0.0145, 0.0374, 0.0151, inf],
[ 0.0081, -0.0072, 0.0039, -0.0157],
[-0.0076, 0.0373, 0.0077, -0.0062],
[ 0.0081, 0.0218, 0.0273, -0.0279]]],
...,
[[[-0.0335, -0.0025, -0.0200, -0.0047],
[-0.0007, 0.0229, 0.0079, 0.0400],
[-0.0125, 0.0203, -0.0186, -0.0202],
[-0.0139, 0.0313, -0.0167, 0.0044]]],
[[[-0.0120, -0.0248, 0.0106, 0.0044],
[-0.0176, -0.0036, -0.0154, 0.0140],
[ 0.0333, -0.0048, 0.0269, 0.0060],
[ 0.0170, -0.0023, 0.0054, 0.0044]]],
[[[ 0.0206, -0.0218, -0.0184, -0.0201],
[ 0.0319, -0.0241, 0.0281, 0.0074],
[ 0.0211, -0.0055, 0.0014, -0.0230],
[-0.0184, 0.0142, 0.0044, 0.0003]]]], device='cuda:0',
dtype=torch.float16)
Here is what’s in down1.model.0:
GeneratorUNet(
(down1): UNetDown(
(model): Sequential(
(0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
( ... )
I logged the norms of the generator’s gradients using the following code:
g_gradient_norms = sorted([param.grad.norm(2).item()
                           for param in G.parameters()
                           if param.grad is not None])
and found that the largest gradient l2-norm was ~5.0 before the update.
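For completeness, a named variant of the same logging (so the largest norm can be traced back to a specific layer) would look like this:

# Same logging keyed by parameter name, largest gradient norms first
g_gradient_norms_named = sorted(
    ((param.grad.norm(2).item(), name)
     for name, param in G.named_parameters()
     if param.grad is not None),
    reverse=True)
print(g_gradient_norms_named[:5])  # the five largest gradient l2-norms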
I am hoping that people with more PyTorch experience might spot something I’m missing here, because I haven’t been able to figure out what’s causing inf to show up and break my network. Could it be underflow because some gradients are too small for half precision? Are there other numerical foot-guns I should be aware of?