I’m using autograd for the first time (as in, I’ve never gone this low level before) and I think I’m doing it wrong. I’d appreciate some help finding my mistake. The issue is as follows:
A runs on every parameter update. B runs every 4 parameter updates (in the same loop as A) and calls C (where I use autograd.grad). On every loop iteration immediately following a run of B, moving my training data to the GPU is very slow (it takes multiple seconds). If I disable B entirely, the lag disappears and everything in A runs as fast as you’d expect. There is no noticeable lag during B or in any of the code that follows it; the stall happens only at the point I’ve marked (moving my batch to the GPU).
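Because CUDA calls are asynchronous, a "slow" transfer can really be the CPU blocking on previously queued GPU work (e.g. the double backward from B) rather than the copy itself. A minimal timing probe with explicit synchronization can separate the two; `timed_transfer` here is a hypothetical helper, not part of my training code:

```python
import time
import torch

def timed_transfer(batch, device):
    # Explicit syncs make the measurement honest: without them, .to(device)
    # can appear slow only because it first waits for kernels that were
    # queued earlier (such as a backward pass with create_graph=True).
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    batch = batch.to(device, non_blocking=True)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return batch, time.perf_counter() - t0

batch, dt = timed_transfer(torch.randn(4, 3, 8, 8), torch.device("cpu"))
```

If the time reported with the leading synchronize is small, the lag is queued GPU work finishing, not the host-to-device copy.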
```python
dataset = ImageFolder(data_path, transform=util.getImToTensorTF())
train_data = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, pin_memory=True,
                                         num_workers=0)
g_opt = torch.optim.AdamW(
    g.parameters(),
    lr=LEARNING_RATE * G_PPL_INTERVAL / (G_PPL_INTERVAL + 1),
    betas=(0, 0.99 ** (G_PPL_INTERVAL / (G_PPL_INTERVAL + 1))),
)

...

real, _ = next(iter(train_data))
real = real.cuda()  # lag here (first CUDA op in the loop)

...

if do_g_ppl:
    path_batch = max(1, BATCH_SIZE // PPL_BATCH)
    latent = util.getMixingLatent(path_batch)
    fake, latents = g(latent)

    ppl_loss, ppl_norms, w_bar = loss_funcs.PPL(latents, fake, w_bar)

    g.zero_grad(set_to_none=True)
    ppl_loss = PPL_WEIGHT * ppl_loss * G_PPL_INTERVAL
    ppl_loss += 0 * fake[0, 0, 0, 0]  # ties the loss to the generator output
    ppl_loss.backward()
    g_opt.step()
```
Where the PPL loss is calculated by this function:
```python
# Perceptual path length.
# Takes an (input, output) pair from the generator
# and a moving average of the path length.
def PPL(latents, gen_img, avg_path_len):
    n, c, h, w = gen_img.shape
    y = torch.randn_like(gen_img) / ((h * w) ** 0.5)

    # Jacobian-vector product J^T y w.r.t. the input latent vector
    grad, = autograd.grad((gen_img * y).sum(), latents, create_graph=True)

    # Per-sample L2 norm of J^T y (norm over latent dim, mean over layers)
    l2norm = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # TODO: do you really need sqrt for the L2 norm as a loss term? Check cost.

    # Update the exponential moving average of the path length
    a = avg_path_len + PPL_DECAY * (l2norm.mean() - avg_path_len)

    # E[(||J^T y|| - a)^2]
    loss = (l2norm - a).pow(2).mean()

    # Return the updated moving average for the next call
    return loss, l2norm, a.detach().item()
```