Running a model multiple times, using previously trained weights

I’m working on the following model for backward stochastic differential equations. I have a time series x_n that I know for all times n = 0, …, N, and a time series y_n of which I know only the terminal value y_N. The update rule for y is a function of y, x and an auxiliary process z at the previous time step. The idea is to iterate backward in time and, at each time increment n, train a network that takes x and n as inputs and outputs y_n and z_n, minimizing the distance to the previously computed value y_{n+1} under the update rule. Note: I use a single network that outputs [y, z] and then unpack the two parts (see the forward method of Model).
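For reference, reading off the loss function below, the scheme is roughly

    y_N = g(x_N),
    y_{n+1} ≈ y_n - f(t_n, x_n, y_n, z_n) · Δt + z_n · ΔW_n,   where t_n = n · T / N.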

A common tip for this type of algorithm is to initialize each time step’s network with the weights learned at the previous time step, which reduces training time. I do this by saving and loading state_dict().
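Schematically, it is the same pattern that appears in the full code below:

torch.save(model.state_dict(), path + "state_dict_" + str(n))       # after finishing time step n
model.load_state_dict(torch.load(path + "state_dict_" + str(n)))    # before starting time step n - 1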

So basically I want to use a single NN multiple times, using previously trained weights. The problem is that even though I reinitialize the model at each time step, I still get the error: “Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed) …”. When I then add retain_graph=True to the backward call, I get a different error: “one of the variables needed for gradient computation has been modified by an inplace operation …”. I tried adding .clone() at various places in my code, but it doesn’t seem to work. I’m confused, since I reinitialize the class BSDEsolver, which has Model as an attribute, at each iteration.

Here are the three important classes in my code: Model (a standard PyTorch nn.Module with a forward method), BSDEsolver (generates the input data and trains a network at a single time step), and BSDEiter (initializes a BSDEsolver and uses it to iterate over the time steps n).


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # or however device is defined in the full script


class Model(nn.Module):
    def __init__(self, equation, dim_h):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(equation.dim_x + 1, dim_h)
        self.linear2 = nn.Linear(dim_h, dim_h)
        self.linear3 = nn.Linear(dim_h, equation.dim_y + equation.dim_y * equation.dim_d)

        self.equation = equation

    def forward(self, batch_size, N, n, x):
        def phi(x):
            x = F.relu(self.linear1(x))
            x = F.relu(self.linear2(x))
            return self.linear3(x)  # [bs, dy + dy*dd]; the z part is reshaped to [bs, dy, dd] below

        delta_t = self.equation.T / N

        u = torch.cat((x, torch.ones(x.size()[0], 1, device=device) * delta_t * n), 1)
        yz = phi(u)
        y = yz[:,:self.equation.dim_y].clone()
        z = yz[:,self.equation.dim_y:].reshape(-1,self.equation.dim_y, self.equation.dim_d).clone()
        return y, z


class BSDEsolver:
    def __init__(self, equation, dim_h):
        self.model = Model(equation, dim_h).to(device)
        self.equation = equation
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def loss(self, x, n, y_target, y, z, w, N):
        if n == N:
            # terminal condition: match y to g(x_N)
            dist = (y - self.equation.g(x)).norm(2, dim=1)
        else:
            delta_t = self.equation.T / N
            estimate = y - self.equation.f(delta_t * n, x, y, z) * delta_t + torch.matmul(z, w).reshape(-1, self.equation.dim_y)
            dist = (y_target - estimate).norm(2, dim=1)
        return torch.mean(dist)

    def gen_bm(self, batch_size, N):
        delta_t = self.equation.T / N
        W = torch.randn(batch_size, self.equation.dim_d, N, device=device) * np.sqrt(delta_t)

        return W

    def gen_forward(self, batch_size, N, W):
        delta_t = self.equation.T / N
        x = self.equation.x_0 + torch.zeros(batch_size, N * self.equation.dim_x, device=device).reshape(-1,self.equation.dim_x, N) #[bs,dx,N]
        for i in range(N-1):
            w = W[:, :, i].reshape(-1, self.equation.dim_d, 1)
            x[:,:,i+1] = x[:,:,i] + self.equation.b(delta_t * i, x[:,:,i]) * delta_t + torch.matmul(self.equation.sigma(delta_t * i, x[:,:,i]),w).reshape(-1, self.equation.dim_x)
        return x

    def train(self, batch_size, N, n, y_prev, itr):
        loss_n = []

        for i in range(itr):
            W = self.gen_bm(batch_size, N)
            x = self.gen_forward(batch_size, N, W)
            x = x[:,:,n]
            w = W[:, :, n].reshape(-1, self.equation.dim_d, 1)

            y, z = self.model(batch_size, N, n, x)
            loss = self.loss(x, n, y_prev, y, z, w, N)

            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)
            self.optimizer.step()
            loss_n.append(loss)

        return loss_n, y

class BSDEiter:
    def __init__(self, equation, dim_h):
        self.equation = equation
        self.dim_h = dim_h

    def train_whole(self, batch_size, N, path, itr):
        loss_data = []
        y_pred = torch.zeros(batch_size, self.equation.dim_y, N, device=device)  # keep on the same device as the model
        z_pred = []

        for n in range(N - 1, -1, -1):
            bsde_solver = BSDEsolver(self.equation, self.dim_h)
            if n == N-1:
                y_prev = y_pred[:,:,n]
            else:
                bsde_solver.model.load_state_dict(torch.load(path+"state_dict_" + str(n+1)), strict=False)
                y_prev = y_pred[:,:,n+1]
            loss_n, y = bsde_solver.train(batch_size, N, n, y_prev, itr)
            loss_data.append(loss_n)
            y_pred[:,:,n] = y.clone()
            torch.save(bsde_solver.model.state_dict(), path + "state_dict_" + str(n))

        return loss_data

I would greatly appreciate some help.

The issue you’re facing is that a computation graph from a previous time step is still being referenced when you call backward() in the current one, so PyTorch tries to backward through a graph whose saved tensors have already been freed. Passing retain_graph=True only papers over this: it forces PyTorch to keep the old graph alive, but optimizer.step() then modifies the network parameters in place, invalidating tensors that the retained graph needs, which is exactly the inplace-operation error you see next.

Instead of reusing the same model instance, create a fresh instance of the Model class at each time step and warm-start it by loading the saved state dictionary; the natural place to do this is the train_whole method of your BSDEiter class. Just as important, make sure nothing the new step consumes still carries the old step’s graph.
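As a sketch of what that can look like: the snippet below assumes the old graph is leaking into the next time step through the tensors you keep across steps, namely the y returned from train (stored in y_pred and fed back as y_prev) and the loss tensors accumulated in loss_n. Detaching those is the usual cure for exactly this pair of errors:

    def train(self, batch_size, N, n, y_prev, itr):
        loss_n = []
        for i in range(itr):
            W = self.gen_bm(batch_size, N)
            x = self.gen_forward(batch_size, N, W)[:, :, n]
            w = W[:, :, n].reshape(-1, self.equation.dim_d, 1)

            y, z = self.model(batch_size, N, n, x)
            loss = self.loss(x, n, y_prev, y, z, w, N)

            self.optimizer.zero_grad()
            loss.backward()               # plain backward, no retain_graph
            self.optimizer.step()
            loss_n.append(loss.item())    # store a float, not a graph-holding tensor
        return loss_n, y.detach()         # cut the graph before y crosses time steps

With the returned y detached, y_prev = y_pred[:, :, n + 1] is plain data, so the new model’s backward pass never reaches into the previous step’s already-freed graph; retain_graph=True can then be dropped, which should also make the inplace-operation error go away.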

Thank you for your reply. But how do I create a new model instance? I thought I already do this by calling bsde_solver = BSDEsolver(self.equation, self.dim_h) at each iteration in train_whole, since BSDEsolver has Model as an attribute.