I’m working with Physics-Informed Neural Networks and establish the physics by computing the derivatives with the torch.autograd.grad() function. I only calculate first-order derivatives in my framework. After that I compute the total loss and call the loss.backward() method on it, which computes the gradients of the loss with respect to all tensors that have requires_grad set to True (the weights and biases, and also the network inputs from which the network outputs were computed). Am I right?
My question is: what should the values of the retain_graph and create_graph options in torch.autograd.grad be, and how do these values affect the backward() call? If they affect it, how? Could anyone shed some light on this, please?
Do you have a minimal reproducible example for your problem?
If you want to compute higher-order gradients, you need to make sure retain_graph=True and create_graph=True are set, because PyTorch otherwise frees the computation graph (the intermediate results of previous operations) after it has been used, unless you tell it not to.
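As a rough illustration (a toy function rather than an actual PINN, so the names here are made up), the pattern looks like this:

import torch

x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
y = x ** 3  # toy "network output"

# First derivative: create_graph=True builds a graph for dy/dx itself,
# so dy/dx can be differentiated again (it also implies retain_graph=True).
dy_dx = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]

# Second derivative: differentiates dy/dx with respect to x, which is only
# possible because the first call kept and extended the graph.
d2y_dx2 = torch.autograd.grad(dy_dx, x, grad_outputs=torch.ones_like(dy_dx))[0]

print(dy_dx)    # 3 * x**2
print(d2y_dx2)  # 6 * x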
def PDE_loss(self, pred, target):
    xy = target
    u, v, p = pred[:, 0:1], pred[:, 1:2], pred[:, 2:3]
    # first-order derivatives
    u_x = torch.autograd.grad(u, xy, torch.ones([u.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 0:1]
    u_y = torch.autograd.grad(u, xy, torch.ones([u.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 1:2]
    v_x = torch.autograd.grad(v, xy, torch.ones([v.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 0:1]
    v_y = torch.autograd.grad(v, xy, torch.ones([v.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 1:2]
    p_x = torch.autograd.grad(p, xy, torch.ones([p.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 0:1]
    p_y = torch.autograd.grad(p, xy, torch.ones([p.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 1:2]
    # second-order derivatives
    u_xx = torch.autograd.grad(u_x, xy, torch.ones([u_x.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 0:1]
    u_yy = torch.autograd.grad(u_y, xy, torch.ones([u_y.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 1:2]
    v_xx = torch.autograd.grad(v_x, xy, torch.ones([v_x.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 0:1]
    v_yy = torch.autograd.grad(v_y, xy, torch.ones([v_y.shape[0], 1]).to(device), retain_graph=True, create_graph=True)[0][:, 1:2]
    # continuity equation
    f0 = u_x + v_y
    # momentum equations
    f1 = u*u_x + v*u_y + p_x - 1/self.Re*(u_xx + u_yy)
    f2 = u*v_x + v*v_y + p_y - 1/self.Re*(v_xx + v_yy)
    mse_f0 = self.MSE_loss(f0, torch.zeros_like(f0))
    mse_f1 = self.MSE_loss(f1, torch.zeros_like(f1))
    mse_f2 = self.MSE_loss(f2, torch.zeros_like(f2))
    return mse_f0 + mse_f1 + mse_f2
def closure(self):
    total_batch_loss = 0.0
    col_batch_loss = 0.0
    bc_batch_loss = 0.0
    outlet_batch_loss = 0.0
    self.model.train()
    col_data, bc_data, outlet_data = self.data[0], self.data[1], self.data[2]
    for i, (col_data, bc_data, outlet_data) in enumerate(zip_longest(col_data, bc_data, outlet_data)):
        total_loss = torch.tensor(0.0).to(device)
        self.optimizer.zero_grad()
        # unpack data
        col_xy = col_data
        bc_xy, bc_uv = bc_data[0], bc_data[1]
        out_xy, out_p = outlet_data[0], outlet_data[1]
        # forward pass
        if col_xy is not None:
            col_xy = col_xy.clone()
            col_xy.requires_grad = True
            col_pred = self.model(col_xy)
            col_loss = self.PDE_loss(col_pred, col_xy)
            total_loss += col_loss
        if bc_xy is not None:
            bc_pred = self.model(bc_xy)
            bc_loss = self.BC_loss(bc_pred, bc_uv)
            total_loss += bc_loss
        if out_xy is not None:
            out_pred = self.model(out_xy)
            outlet_loss = self.Outlet_loss(out_pred, out_p)
            total_loss += outlet_loss
        total_loss.backward()
        if self.optimizer.__class__.__name__ != "LBFGS":
            self.optimizer.step()
        total_batch_loss += total_loss.detach().cpu()
        col_batch_loss += col_loss.detach().cpu()
        bc_batch_loss += bc_loss.detach().cpu()
        outlet_batch_loss += outlet_loss.detach().cpu()
    self.losses["total"].append(total_batch_loss)
    self.losses["pde"].append(col_batch_loss)
    self.losses["bc"].append(bc_batch_loss)
    self.losses["outlet"].append(outlet_batch_loss)
In the code above, I don’t see a rationale for using retain_graph and create_graph when calculating the second-order gradients. I understand that these options need to be switched on when computing the first-order gradients, so that they can be used when computing the second-order gradients.
When computing second-order gradients, you’ll need to make sure the first-order gradient calls have retain_graph=True and create_graph=True, as you mentioned. You shouldn’t need them for the second-order calls, as you’re not computing any higher-order gradients from those results.
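For example, a quick toy check (illustrative names, not taken from your code) shows why the first-order call is the one that needs create_graph=True:

import torch

x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
u = torch.sin(x)

# First-order derivative WITHOUT create_graph: the result comes back detached
# (no grad_fn), so it cannot be differentiated again.
u_x = torch.autograd.grad(u, x, torch.ones_like(u), retain_graph=True)[0]
print(u_x.grad_fn)  # None

# A second derivative therefore fails, because autograd has nothing to
# differentiate through.
try:
    torch.autograd.grad(u_x, x, torch.ones_like(u_x))
except RuntimeError as err:
    print(err)  # complains that u_x does not require grad / has no grad_fn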
The docs of torch.autograd.grad can be found here: torch.autograd.grad — PyTorch 2.3 documentation
When using torch.autograd.grad, it returns the gradient of outputs with respect to inputs (as a tuple with one tensor per input). When calling .backward(), gradients are computed and accumulated into the .grad attribute of every leaf tensor that requires gradients. This is why torch.optim.Optimizer objects, e.g. torch.optim.Adam, require you to call optim.zero_grad() before computing loss.backward(): if you don’t, the .grad attribute of your tensors will hold the gradients of the current backward pass added on top of those from previous ones.
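As a small illustration of that accumulation behaviour (toy tensors, not your model):

import torch

w = torch.tensor(2.0, requires_grad=True)

loss = (w * 3.0) ** 2   # d(loss)/dw = 18 * w = 36
loss.backward()
print(w.grad)           # tensor(36.)

# Without zeroing, a second backward pass accumulates into .grad:
loss = (w * 3.0) ** 2
loss.backward()
print(w.grad)           # tensor(72.) = 36 + 36

# Resetting the gradient (what optimizer.zero_grad() does for its parameters):
w.grad = None
loss = (w * 3.0) ** 2
loss.backward()
print(w.grad)           # tensor(36.) again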
Thanks a ton for your timely support. I also had another doubt: why do we need to use retain_graph in the loss.backward(retain_graph=True) call? What is the rationale behind doing this?
I don’t think you need retain_graph=True within loss.backward(). If the interpreter is throwing an error asking you to call loss.backward(retain_graph=True), perhaps it is because of how you mini-batch your gradients? But I’m not 100% sure.
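As a minimal illustration of where that error usually comes from (a toy graph, not your training loop): calling backward() a second time on a graph that has already been freed raises the “trying to backward through the graph a second time” error, and retain_graph=True on the first call avoids it.

import torch

w = torch.tensor(1.0, requires_grad=True)
loss = (w * 2.0) ** 2

loss.backward(retain_graph=True)  # keeps the graph alive
loss.backward()                   # fine; without retain_graph above, this second
                                  # call would raise a RuntimeError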
Suppose I compute the gradient of the network output with respect to the input using torch.autograd.grad(), in order to build the PDE residual loss. When I then call loss.backward(), do the gradients get accumulated in the tensors that result from the torch.autograd.grad() call? If so, do I need to use the create_graph=True option in the torch.autograd.grad() call?
If I understand you correctly, you’re asking whether the .grad attribute is populated on the PDE residual terms that are calculated via the torch.autograd.grad calls. Would these be the u_xx, u_yy, etc. terms?
I don’t think this would be an issue, as the only gradients you care about are those with respect to the parameters, and you could check whether these intermediate tensors have .grad attributes by inspecting u_xx.grad.
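A tiny sketch of what I mean (toy stand-ins for the network output and a parameter):

import torch

x = torch.linspace(0.0, 1.0, 4, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)  # stands in for a network parameter
u = w * x ** 2

u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
loss = (u_x ** 2).mean()
loss.backward()

print(w.grad)    # populated: the loss was backpropagated through u_x to w
print(u_x.grad)  # None: u_x is a non-leaf tensor, so .backward() does not store
                 # a .grad on it (PyTorch may also warn about accessing it)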
You’re performing higher-order derivatives, so I think you need to keep create_graph=True. If you’re unsure, you can always run your code with create_graph=True and with create_graph=False and see whether you get the same results.
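Something along these lines (a made-up toy network, just to show what the flag changes) could serve as that check:

import torch

net = torch.nn.Sequential(torch.nn.Linear(1, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
x = torch.linspace(0.0, 1.0, 16).reshape(-1, 1).requires_grad_(True)

u = net(x)
u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]

# With create_graph=True the second derivative stays attached to the graph,
# so a loss built from it can still reach the network parameters.
u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
print(u_xx.requires_grad)  # True

# With create_graph=False it comes back detached: a loss made only of this
# tensor could not be backpropagated to the parameters at all.
u_xx_detached = torch.autograd.grad(u_x, x, torch.ones_like(u_x), retain_graph=True)[0]
print(u_xx_detached.requires_grad)  # False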