Autograd.grad computation time becomes longer and longer

Hi dear community, I want to post a question about autograd.

Here are my functions:

import time

import torch
from torch import autograd


def gradients(net, inputdata):
    # Gradient of the network output w.r.t. the input; create_graph=True keeps
    # the result differentiable for the following computations.
    inputs_res = inputdata.clone()
    outputs_res = net(inputs_res)
    gradients = autograd.grad(outputs=outputs_res, inputs=inputs_res,
                              grad_outputs=torch.ones(outputs_res.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    return gradients

def nn_sym4(x0, y0, h, net):
    # Fourth-order symplectic step composed of three Verlet sub-steps.
    gamma1 = torch.tensor(1.0 / (2.0 - 2.0 ** (1.0 / 3.0))).reshape(-1, 1)
    gamma2 = 1.0 - 2.0 * gamma1
    x, y = x0, y0
    x, y = nn_verlet(x, y, h * gamma1, net)
    x, y = nn_verlet(x, y, h * gamma2, net)
    x, y = nn_verlet(x, y, h * gamma1, net)
    return x, y

def nn_verlet(x0, y0, h, net):
    # One Verlet step; the prints time each concatenation and gradient call.
    a = time.time()
    xy0 = torch.cat((x0, y0), dim = 1)
    b = time.time()
    print('cat1:',b-a)
    retx = xy0[:,0:2].reshape(-1,2) +  (gradients(net, xy0)[:,2:4].reshape(-1,2))* h / 2.0
    c = time.time()
    print('grad1:',c-b)
    pv1 = torch.cat((retx, xy0[:,2:4].reshape(-1,2)), dim = 1)
    d = time.time()
    print('cat2:',d-c)
    rety = xy0[:,2:4].reshape(-1,2) - (gradients(net, pv1)[:,0:2].reshape(-1,2)) * h
    e = time.time()
    print('grad2:',e-d)
    pv2 = torch.cat((retx, rety), dim = 1)
    f = time.time()
    print('cat3:',f-e)
    retx = retx + (gradients(net, pv2)[:,2:4].reshape(-1,2)) * h / 2.0
    g = time.time()
    print('grad3:',g-f)
    return retx, rety

def generate_data(neth, pos_t0, vel_t0, batch_size, NT):
    # dt is assumed to be a globally defined step size.
    pos = pos_t0.clone().detach().requires_grad_(True)
    vel = vel_t0.clone().detach().requires_grad_(True)
    for t in range(NT):
        pos, vel = nn_sym4(pos, vel, dt, neth)
    return pos, vel

When I run generate_data(), computing nn_sym4() is really slow, because as the number of iterations grows, gradients() needs more and more time. I guess it is because I call autograd.grad() too many times, so the graph keeps accumulating and gets bigger and bigger? Is there any way to make the process quicker? I tried to release the graph but failed, and setting the gradients to zero doesn't work either.

Thanks!

I guess this might be the case as you are using retain_graph=True in your code.
Could you explain a bit more why this is needed and format your code by wrapping it into three backticks ```?

Thank you for your reply, sir. I tried it; if I set retain_graph=False, it raises: "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."

This error would be raised as it seems you are reusing the output as a new input in this loop:

    for t in range(NT):
        pos, vel = nn_sym4(pos, vel, dt, neth)

and are thus growing the computation graph with every iteration.
Is this desired? If so, then the slowdown in the gradient computation would be expected, as you would need to calculate gradients through a larger and larger graph in each iteration.
If not, you might want to detach() the inputs to nn_sym4 and remove the retain_graph=True argument.
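
For example, a rough sketch of generate_data with that change (assuming you only need the values of pos and vel and no gradients have to flow between time steps):

    def generate_data(neth, pos_t0, vel_t0, batch_size, NT):
        pos = pos_t0.clone().detach().requires_grad_(True)
        vel = vel_t0.clone().detach().requires_grad_(True)
        for t in range(NT):
            pos, vel = nn_sym4(pos, vel, dt, neth)
            # Cut the graph after every step so it cannot grow across iterations,
            # and re-enable requires_grad so gradients() can differentiate the next step.
            pos = pos.detach().requires_grad_(True)
            vel = vel.detach().requires_grad_(True)
        return pos, vel

With the graph cut after each step, every nn_sym4 call only differentiates through its own three Verlet sub-steps, so the time per iteration should stay roughly constant.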

Yes, I need to iterate this NT times, and I need to keep all the values of pos and vel as my generated data set…

Thank you sir! I solved it; it was truly the problem you indicated. I detach() the inputs of nn_sym4(), but set requires_grad=True again in nn_verlet() before calling gradients().
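
For reference, the change looks roughly like this (only the relevant lines, as a sketch):

    def nn_sym4(x0, y0, h, net):
        # Detach the inputs so the graph of the previous time step is dropped.
        x, y = x0.detach(), y0.detach()
        ...

    def nn_verlet(x0, y0, h, net):
        xy0 = torch.cat((x0, y0), dim=1)
        # A freshly detached input no longer requires grad, so turn it back on
        # before calling gradients(); later sub-steps already require grad.
        if not xy0.requires_grad:
            xy0.requires_grad_(True)
        ...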
