I’m working on the following model for backward stochastic differential equations (BSDEs). I have a time series x_n that I know for all times n = 0, …, N, and a time series y_n for which I know only y_N. The update rule gives y_{n+1} as a function of x, y, and an auxiliary process z at the previous time step n. The idea is to iterate backward in time. At each time step n, I train a network that takes x_n and n as inputs and outputs y_n and z_n, minimizing the distance between the already computed y_{n+1} and the value the update rule predicts from (x_n, y_n, z_n). Note: I use a single network that outputs [y, z] and then unpack them (see the forward method in Model).
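Read off the `loss` and `gen_bm` methods in the code below, the update rule and the per-step objective appear to be the following Euler-type scheme. Here Δt = T/N, t_n = nΔt, ΔW_n ~ N(0, Δt·I) has dimension dim_d, z_n is a dim_y × dim_d matrix, and (y_n^θ, z_n^θ) is the network output for the input (x_n, t_n):

$$
\begin{aligned}
y_{n+1} &\approx y_n - f(t_n, x_n, y_n, z_n)\,\Delta t + z_n\,\Delta W_n, \qquad y_N = g(x_N),\\
\mathcal{L}_n(\theta) &= \mathbb{E}\Big[\big\| y_{n+1} - \big(y_n^\theta - f(t_n, x_n, y_n^\theta, z_n^\theta)\,\Delta t + z_n^\theta\,\Delta W_n\big)\big\|_2\Big].
\end{aligned}
$$

At step n, y_{n+1} comes from the later time step that was already solved and is meant to act as a fixed target; only the parameters θ of the step-n network are trained.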
A common tip for this type of algorithm is to initialize the network at each time step with the weights learned at the previous step, which reduces training time. I do this by saving and loading the state_dict().
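If writing to disk is not needed, the same warm start can be sketched in memory. This is a minimal sketch, assuming a hypothetical `make_net` as a stand-in for `Model` with the same architecture at every step; the training itself is left as a placeholder. Note that `strict=False` in the original `load_state_dict` call silently ignores missing or unexpected keys; with an identical architecture at every step, the default strict loading is a safer check.

```python
import copy

import torch
import torch.nn as nn


def make_net(dim_in: int, dim_out: int, dim_h: int = 16) -> nn.Module:
    # Hypothetical stand-in for Model: same architecture at every time step.
    return nn.Sequential(
        nn.Linear(dim_in, dim_h), nn.ReLU(),
        nn.Linear(dim_h, dim_h), nn.ReLU(),
        nn.Linear(dim_h, dim_out),
    )


def backward_sweep(N: int, dim_in: int = 3, dim_out: int = 2) -> dict:
    snapshots = {}      # time step n -> copy of the weights obtained at step n
    prev_state = None
    for n in range(N - 1, -1, -1):
        net = make_net(dim_in, dim_out)
        if prev_state is not None:
            # Warm start from step n+1; strict loading catches architecture mismatches.
            net.load_state_dict(prev_state)
        # ... train `net` at time step n here (placeholder) ...
        # state_dict() returns tensors that share storage with the live parameters,
        # so deep-copy it to keep a snapshot that later training cannot change.
        prev_state = copy.deepcopy(net.state_dict())
        snapshots[n] = prev_state
    return snapshots


snaps = backward_sweep(N=4)
print(sorted(snaps))
```

Because no training happens in this sketch, every snapshot equals the randomly initialized weights of step N−1, which makes it easy to check that the warm start copies the weights exactly.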
So basically I want to use a single NN multiple times, each time starting from previously trained weights. The problem is that even though I reinitialize the model at each time step, I still get the error: “Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed) …”. When I then add “retain_graph=True” to the loss.backward() call, I get a different error: “one of the variables needed for gradient computation has been modified by an inplace operation …”. I tried adding .clone() at various places in my code, but it doesn’t seem to help. I’m confused, since I create a new BSDEsolver instance, which holds Model as an attribute, at every time step.
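One likely culprit, sketched under simplifying assumptions: `y_pred[:,:,n] = y.clone()` copies a tensor that still carries the autograd graph of the step-n model into `y_pred`, so `y_prev` at the next step is not a plain target but part of that old graph. `.clone()` keeps the graph, whereas `.detach()` drops it. The toy below uses two `nn.Linear` layers as hypothetical stand-ins for the networks at steps n+1 and n and reproduces the first error; storing a detached value lets the loop run with a plain `loss.backward()`. Presumably the second error has the same root: with `retain_graph=True` the backward pass still reaches the old model's graph, whose parameters were changed in place by that model's `optimizer.step()` after `y` was computed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def run(detach_target: bool, itr: int = 3) -> str:
    x = torch.randn(8, 2)
    prev_net = nn.Linear(2, 1)          # stand-in for the model trained at step n+1
    y_pred = torch.zeros(8, 3)          # plays the role of y_pred in BSDEiter
    y = prev_net(x).squeeze(1)          # output that still carries prev_net's graph
    # .clone() keeps the graph attached; .detach() cuts it.
    y_pred[:, 2] = y.detach() if detach_target else y.clone()
    y_prev = y_pred[:, 2]               # target used at the next time step

    net = nn.Linear(2, 1)               # fresh model for step n
    opt = torch.optim.Adam(net.parameters())
    try:
        for _ in range(itr):
            loss = ((net(x).squeeze(1) - y_prev) ** 2).mean()
            opt.zero_grad()
            loss.backward()             # no retain_graph
            opt.step()
    except RuntimeError as err:
        return "error: " + str(err)
    return "ok"


print(run(detach_target=False)[:60])
print(run(detach_target=True))
```

In the original code, the corresponding change would be storing `y.detach()` instead of `y.clone()` in `y_pred`, after which `retain_graph=True` should no longer be necessary.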
Here are the three important classes in my code: Model (a standard PyTorch module with a forward method), BSDEsolver (generates the input data and trains a network at one time step), and BSDEiter (creates a BSDEsolver at each time step and uses it to iterate backward over the time steps n).
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model(nn.Module):
    def __init__(self, equation, dim_h):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(equation.dim_x + 1, dim_h)
        self.linear2 = nn.Linear(dim_h, dim_h)
        self.linear3 = nn.Linear(dim_h, equation.dim_y + equation.dim_y * equation.dim_d)
        self.equation = equation

    def forward(self, batch_size, N, n, x):
        def phi(x):
            x = F.relu(self.linear1(x))
            x = F.relu(self.linear2(x))
            return self.linear3(x)  # [bs, dy + dy*dd]

        delta_t = self.equation.T / N
        # Append the time t_n = n * delta_t as an extra input feature.
        u = torch.cat((x, torch.ones(x.size(0), 1, device=device) * delta_t * n), 1)
        yz = phi(u)
        y = yz[:, :self.equation.dim_y].clone()
        # [bs, dy*dd] -> [bs, dy, dd]
        z = yz[:, self.equation.dim_y:].reshape(-1, self.equation.dim_y, self.equation.dim_d).clone()
        return y, z


class BSDEsolver():
    def __init__(self, equation, dim_h):
        self.model = Model(equation, dim_h).to(device)
        self.equation = equation
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def loss(self, x, n, y_target, y, z, w, N):
        if n == N:
            dist = (y - self.equation.g(x)).norm(2, dim=1)
        else:
            delta_t = self.equation.T / N
            estimate = (y - self.equation.f(delta_t * n, x, y, z) * delta_t
                        + torch.matmul(z, w).reshape(-1, self.equation.dim_y))
            dist = (y_target - estimate).norm(2, dim=1)
        return torch.mean(dist)

    def gen_bm(self, batch_size, N):
        # Brownian increments, shape [bs, dd, N]
        delta_t = self.equation.T / N
        W = torch.randn(batch_size, self.equation.dim_d, N, device=device) * np.sqrt(delta_t)
        return W

    def gen_forward(self, batch_size, N, W):
        # Euler scheme for x, shape [bs, dx, N]
        delta_t = self.equation.T / N
        x = self.equation.x_0 + torch.zeros(batch_size, N * self.equation.dim_x, device=device).reshape(-1, self.equation.dim_x, N)
        for i in range(N - 1):
            w = W[:, :, i].reshape(-1, self.equation.dim_d, 1)
            x[:, :, i + 1] = (x[:, :, i] + self.equation.b(delta_t * i, x[:, :, i]) * delta_t
                              + torch.matmul(self.equation.sigma(delta_t * i, x[:, :, i]), w).reshape(-1, self.equation.dim_x))
        return x

    def train(self, batch_size, N, n, y_prev, itr):
        loss_n = []
        for i in range(itr):
            W = self.gen_bm(batch_size, N)
            x = self.gen_forward(batch_size, N, W)
            x = x[:, :, n]
            w = W[:, :, n].reshape(-1, self.equation.dim_d, 1)
            y, z = self.model(batch_size, N, n, x)
            loss = self.loss(x, n, y_prev, y, z, w, N)
            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)
            self.optimizer.step()
            loss_n.append(loss)
        return loss_n, y


class BSDEiter():
    def __init__(self, equation, dim_h):
        self.equation = equation
        self.dim_h = dim_h

    def train_whole(self, batch_size, N, path, itr):
        loss_data = []
        y_pred = torch.zeros(batch_size, self.equation.dim_y, N)
        z_pred = []
        for n in range(N - 1, -1, -1):
            bsde_solver = BSDEsolver(self.equation, self.dim_h)
            if n == N - 1:
                y_prev = y_pred[:, :, n]
            else:
                bsde_solver.model.load_state_dict(torch.load(path + "state_dict_" + str(n + 1)), strict=False)
                y_prev = y_pred[:, :, n + 1]
            loss_n, y = bsde_solver.train(batch_size, N, n, y_prev, itr)
            loss_data.append(loss_n)
            y_pred[:, :, n] = y.clone()
            torch.save(bsde_solver.model.state_dict(), path + "state_dict_" + str(n))
        return loss_data
```
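For comparison, here is a compact sketch of the whole backward sweep under toy assumptions that are not part of the original problem: dim_x = dim_y = dim_d = 1, b = 0, sigma = 1, driver f = 0 and terminal condition g(x) = x, so the exact solution is y_n = x_n and z_n = 1. It deliberately differs from the code above in a few ways. A single module is reused in memory, so its current weights are the warm start for the next step. The target is computed under `torch.no_grad()`, so no graph crosses time steps and a plain `backward()` suffices. The paths are sampled at all N+1 times, so g(x_N) is used at the first backward step. Finally, instead of a fixed stored `y_prev`, the target y_{n+1} comes from a frozen copy of the step-(n+1) network evaluated on the same freshly sampled paths as x_n; this is a swapped-in design choice, motivated by the fact that `train` draws new `W` and `x` every iteration while `y_prev` stays fixed.

```python
import copy
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy problem (assumptions): dx = dW, x_0 = 0, f = 0, g(x) = x.
T, N, BS, DIM_H = 1.0, 4, 256, 32
DT = T / N


def g(x):
    return x


def f(t, x, y, z):
    return torch.zeros_like(y)


def make_net():
    # Input [x_n, t_n], output [y_n, z_n] (dim_y = dim_d = 1).
    return nn.Sequential(nn.Linear(2, DIM_H), nn.ReLU(),
                         nn.Linear(DIM_H, DIM_H), nn.ReLU(),
                         nn.Linear(DIM_H, 2))


def sample_paths(bs):
    dW = torch.randn(bs, N) * math.sqrt(DT)
    # x_{k+1} = x_k + dW_k for k = 0..N-1, giving N + 1 time points.
    x = torch.cat([torch.zeros(bs, 1), dW.cumsum(dim=1)], dim=1)
    return x, dW


def make_y_fn(frozen, n):
    # Returns a function evaluating the frozen step-n network's y output at t_n.
    def y_fn(x):
        t = torch.full_like(x, n * DT)
        return frozen(torch.cat([x, t], dim=1))[:, :1]
    return y_fn


def train_step(net, n, y_next_fn, itr=200):
    opt = torch.optim.Adam(net.parameters(), lr=1e-2)
    for _ in range(itr):
        x, dW = sample_paths(BS)
        xn, dWn = x[:, n:n + 1], dW[:, n:n + 1]
        with torch.no_grad():  # the target carries no graph from earlier steps
            target = y_next_fn(x[:, n + 1:n + 2])
        y, z = net(torch.cat([xn, torch.full_like(xn, n * DT)], dim=1)).split(1, dim=1)
        estimate = y - f(n * DT, xn, y, z) * DT + z * dWn
        loss = (target - estimate).norm(dim=1).mean()
        opt.zero_grad()
        loss.backward()  # plain backward; retain_graph is not needed
        opt.step()
    return loss.item()


def backward_sweep():
    net = make_net()  # one module, reused: its current weights warm-start each step
    y_next_fn, losses = g, {}
    for n in range(N - 1, -1, -1):
        losses[n] = train_step(net, n, y_next_fn)
        frozen = copy.deepcopy(net).requires_grad_(False)
        y_next_fn = make_y_fn(frozen, n)
    return losses


losses = backward_sweep()
print(losses)
```

Returning `loss.item()` rather than the loss tensor also avoids keeping every iteration's graph alive, which `loss_n.append(loss)` in the original does.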
I would greatly appreciate some help.