I am trying to run for loops inside each training iteration (part of my code is attached below). What I would like to achieve is that in each pass of the ii and i loops I have fresh tensors with no gradients: at the end of each pass I save the output to a storage tensor, and before the next iteration (epoch) these are backpropagated.
First of all, I do not know how to make the tensor and all intermediate tensors in each of the for loops gradient-free after each pass of the ii loop ends, so I know there is a problem there. Second, when I run this code, it gets through the first epoch and I presume it goes on into the second, but it fails to backpropagate the second time. I get the error “Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().”
Please note that I have not included the model here, as the model function works fine when I do not have these two for loops inside the epoch (iteration). I also tried retain_graph=True with this loop, and it gives me an in-place variable error. But I do not want to use this option, and I am more interested in what is happening. I would appreciate any assistance.
for epoch in range(epochs):
    for ii in range(seis_forward.shape):  # for all shots
        for i in range(tmmax_all_times.shape):
            time_step = tmmax_all_times[:, :, i]
            time_step = torch.flatten(time_step)
            time_step = nxx.reshape(len(time_step), 1)
            net_Input_tensor = torch.cat((vp_new, nxx, nzz, time_step), dim=1)
            x = net_Input_tensor
            u, u_xx, u_zz, u_tt, vel_pred = solve_pde_pinn_fwi(x, dnn_Linear.train(), vel_in)
            # loss PDE
            u = u.reshape(self.v0.T.shape)
            u_xx = u_xx.reshape(self.v0.T.shape)
            u_zz = u_zz.reshape(self.v0.T.shape)
            u_tt = u_tt.reshape(self.v0.T.shape)
            u_train[:, :, i, ii] = u
            u_xx_train[:, :, i, ii] = u_xx
            u_zz_train[:, :, i, ii] = u_zz
            u_tt_train[:, :, i, ii] = u_tt
            seis_pred[:, :, ii] = u[receivers_depth, first_rcvr_point-1:last_rcvr_point]
            vpp = vel_pred.reshape(self.v0.shape).T
            wave_analytic_from_nn = u_tt - ((vpp)**2) * (u_xx + u_zz)
            wave_analytic_from_nn_all[:, :, i, ii] = wave_analytic_from_nn
    PDE_Loss = loss(wave_analytic_from_nn_all, u_nn_zeros)
    seismic_loss = loss(seis_pred, seis_forward)
    Total_loss = seismic_loss + PDE_Loss
    optimizer.zero_grad()
    Total_loss.backward()
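
For reference, here is a minimal, self-contained sketch that appears to reproduce the same error message with the same pattern as the code above. It assumes the storage tensors (u_train, seis_pred, wave_analytic_from_nn_all, and so on) are allocated once before the epoch loop, which the excerpt does not show. Everything in it (the small Linear net, the buffer buf, the function run, the flag fresh_buffer, the call to optimizer.step()) is a hypothetical stand-in, not the real model or training code.

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(3, 1)  # hypothetical stand-in for dnn_Linear
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
loss = torch.nn.MSELoss()
x = torch.randn(5, 3)
target = torch.zeros(5, 4)


def run(epochs, fresh_buffer):
    # Assumption: the storage tensor is created once, outside the epoch loop.
    buf = torch.zeros(5, 4)
    for epoch in range(epochs):
        if fresh_buffer:
            # A new tensor carries no autograd history from the previous epoch.
            buf = torch.zeros(5, 4)
        for i in range(4):
            # The in-place write attaches this epoch's graph to buf, on top of
            # whatever history buf already carries.
            buf[:, i] = net(x).squeeze(1) * (i + 1)
        total = loss(buf, target)
        optimizer.zero_grad()
        total.backward()  # frees the saved tensors of the graph it walks
        optimizer.step()  # assumed here; not shown in the excerpt above
    return total.item()


try:
    run(2, fresh_buffer=False)
except RuntimeError as err:
    print("reused buffer:", str(err)[:60])

print("fresh buffer loss:", run(2, fresh_buffer=True))
```

With the reused buffer, the second epoch's backward() has to pass through the graph recorded in the first epoch, because buf still points to it, and that graph has already been freed. Re-creating the buffer at the start of each epoch avoids this in the sketch; whether the same applies to the real code depends on where its storage tensors are allocated.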