Hi,
I am trying to run nested for loops inside each training iteration (part of my code is attached below). What I would like to achieve is that in each ii and i iteration I have fresh tensors with no gradient history. At the end of the inner loop I save the outputs into tensors, and before the next iteration (epoch) these are backpropagated.
First, I do not know how to make the tensors, and all intermediate tensors in each of the for loops, gradient-free after each ii loop ends, so I know there is a problem there. Second, when I run this code it gets through the first epoch and, I presume, continues into the second, but it fails to backpropagate the second time. I get the error “Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().”
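A toy reproduction may help pin down where this error comes from. The sketch below assumes that u_train, seis_pred and wave_analytic_from_nn_all are allocated once, before the epoch loop (their allocation is not shown in the post). Writing into a slice of such a tensor is an in-place operation, so the buffer keeps its autograd history from epoch 0, and the epoch 1 backward walks back into that already-freed graph. The names train, w and buf are hypothetical stand-ins: w for the network parameters, buf for one of the buffers.

```python
import torch


def train(epochs: int, fresh_buffer: bool) -> str:
    """Fill a buffer step by step, then backprop once per epoch (toy version of the loop)."""
    w = torch.ones(1, requires_grad=True)  # hypothetical stand-in for the network parameters
    buf = torch.zeros(3)  # hypothetical buffer, allocated ONCE before the epoch loop
    for epoch in range(epochs):
        if fresh_buffer:
            buf = torch.zeros(3)  # new tensor: carries no history from earlier epochs
        for i in range(3):
            # Slice assignment is in-place: it chains a new node onto buf's existing history
            buf[i] = (w * w * (i + 1)).squeeze()
        loss = buf.sum()
        w.grad = None  # plays the role of optimizer.zero_grad()
        try:
            loss.backward()  # frees the saved tensors of every node it visits
        except RuntimeError as err:
            return f"epoch {epoch}: {err}"
    return "ok"


print(train(3, fresh_buffer=True))   # runs all epochs
print(train(3, fresh_buffer=False))  # expected to fail at epoch 1
```

In this sketch the only difference between the two runs is whether the buffer is re-created at the top of each epoch; the stale buffer is what drags the freed epoch 0 graph into the second backward.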
Please note that I have not included the model here, since the model function works fine when these two for loops are not inside the epoch loop. I also tried retain_graph=True, but then I get an in-place operation error. In any case I would rather not use this option; I am more interested in understanding what is happening. I appreciate any assistance.
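The retain_graph=True variant can be sketched the same way, under the same assumption that the buffer is allocated once. Retaining the graph keeps epoch 0's saved tensors alive, but the in-place parameter update (done by hand here in place of optimizer.step()) changes w after the old graph saved it, so the next backward through that old graph is expected to fail with an in-place modification error instead. train_retain, w and buf are hypothetical names.

```python
import torch


def train_retain(epochs: int) -> str:
    """Same toy loop, buffer allocated once, but with retain_graph=True and a parameter update."""
    w = torch.ones(1, requires_grad=True)  # hypothetical stand-in for the network parameters
    buf = torch.zeros(3)  # hypothetical buffer, still allocated once before the loop
    for epoch in range(epochs):
        for i in range(3):
            buf[i] = (w * w * (i + 1)).squeeze()  # w * w saves w for the backward pass
        loss = buf.sum()
        w.grad = None
        try:
            loss.backward(retain_graph=True)  # epoch 0's saved tensors are kept alive
        except RuntimeError as err:
            return f"epoch {epoch}: {err}"
        with torch.no_grad():
            w -= 0.1 * w.grad  # in-place update, standing in for optimizer.step()
    return "ok"


print(train_retain(3))  # expected: an "inplace operation" error at epoch 1
```

So retain_graph=True only moves the failure: the old graph is still reachable through the buffer, and its saved copy of w no longer matches the updated parameter.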
```python
for epoch in range(epochs):
    for ii in range(seis_forward.shape[2]):  # for all shots
        for i in range(tmmax_all_times.shape[2]):  # for all time steps
            time_step = tmmax_all_times[:, :, i]
            time_step = torch.flatten(time_step)
            # originally nxx.reshape(...), which discarded the flattened time_step
            time_step = time_step.reshape(len(time_step), 1)
            net_Input_tensor = torch.cat((vp_new, nxx, nzz, time_step), dim=1)
            x = net_Input_tensor
            u, u_xx, u_zz, u_tt, vel_pred = solve_pde_pinn_fwi(x, dnn_Linear.train(), vel_in)
            # loss PDE
            u = u.reshape(self.v0.T.shape)
            u_xx = u_xx.reshape(self.v0.T.shape)
            u_zz = u_zz.reshape(self.v0.T.shape)
            u_tt = u_tt.reshape(self.v0.T.shape)
            u_train[:, :, i, ii] = u
            u_xx_train[:, :, i, ii] = u_xx
            u_zz_train[:, :, i, ii] = u_zz
            u_tt_train[:, :, i, ii] = u_tt
            seis_pred[:, :, ii] = u[receivers_depth, first_rcvr_point - 1:last_rcvr_point]
            vpp = vel_pred.reshape(self.v0.shape).T
            wave_analytic_from_nn = u_tt - (vpp ** 2) * (u_xx + u_zz)
            wave_analytic_from_nn_all[:, :, i, ii] = wave_analytic_from_nn
    PDE_Loss = loss(wave_analytic_from_nn_all, u_nn_zeros)
    seismic_loss = loss(seis_pred, seis_forward)
    Total_loss = seismic_loss + PDE_Loss
    optimizer.zero_grad()
    Total_loss.backward()
```
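One way to restructure such a loop, sketched on a toy problem: collect each step's output in Python lists and torch.stack them once per epoch, so every epoch builds its own graph and nothing is carried over. The model (nn.Linear), grid, target and all shapes are hypothetical stand-ins for dnn_Linear/solve_pde_pinn_fwi, nxx and seis_forward. The same idea should also work by re-allocating u_train and the other buffers with torch.zeros at the top of each epoch.

```python
import torch
from torch import nn


def train_with_lists(epochs: int = 5, n_shots: int = 2, n_steps: int = 4) -> list:
    torch.manual_seed(0)
    model = nn.Linear(2, 1)  # hypothetical stand-in for dnn_Linear / solve_pde_pinn_fwi
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    grid = torch.rand(16, 1)  # hypothetical stand-in for nxx
    target = torch.zeros(n_steps, n_shots, 16)  # hypothetical stand-in for the observed data
    losses = []
    for epoch in range(epochs):
        per_shot = []  # rebuilt every epoch, so no graph survives from the previous one
        for ii in range(n_shots):
            per_step = []
            for i in range(n_steps):
                t = torch.full_like(grid, float(i))  # time coordinate for this step
                u = model(torch.cat((grid, t), dim=1)).squeeze(1)  # shape (16,)
                per_step.append(u)
            per_shot.append(torch.stack(per_step))  # shape (n_steps, 16)
        pred = torch.stack(per_shot, dim=1)  # shape (n_steps, n_shots, 16)
        loss = nn.functional.mse_loss(pred, target)
        optimizer.zero_grad()
        loss.backward()  # no retain_graph needed
        optimizer.step()  # in-place update is safe: the next epoch builds a new graph
        losses.append(loss.item())
    return losses


print(train_with_lists())
```

If some stored values are only needed for plotting or inspection, keeping u.detach() for those copies leaves them out of the graph; the tensors that feed the loss must stay attached.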