Kernel dies while running torch.optim.Adam

I have these two functions, one written in pytorch and the other written in Jax. The one written in pytorch kills my kernel on my linux laptop (16 gb ram) and also on kills the kernel (out of memory) on my colab pro plus as soon as I set the num_epochs above 1e5. This happens even when I move device to GPU. Meanwhile the one written in jax runs without any issues even when I set the num_epochs to 1e6. I would like to know why this happens because I have a program I really wanna write using pytorch. Thank you

Pytorch

def fit_stochastic(self,learning_rate:float, num_epochs:int):
        self.learning_rate = learning_rate
        self.num_epochs = int(num_epochs)
        self.par_log = self.convert_to_internal(self.p0).requires_grad_(True)
        optimizer = torch.optim.Adam(params=[self.par_log], lr = learning_rate)
        self.losses = []
        for epoch in range(self.num_epochs):

            optimizer.zero_grad()
            #loss = get_loss(param,f_data,Z_data)
            self.loss = self.cost_func(self.par_log, self.F, self.Z, self.Zerr, self.lb_mat, self.ub_mat, self.smf)

            if epoch%int(self.num_epochs/10)==0:
                print("" + str(epoch) + ": "
                    + "loss=" + "{:5.3e}".format(self.loss)
                ) 
            self.losses.append(self.loss.clone())   
            self.loss.backward()
            optimizer.step()
        self.popt = self.convert_to_external(self.par_log)
        self.perr = self.compute_perr(self.popt, self.F, self.Z, self.Zerr)
        self.chi_sqr = torch.mean(functorch.vmap(self.wrms_func, in_dims=1)(self.popt, self.F, self.Z, self.Zerr))
        self.aic = torch.mean(functorch.vmap(self.compute_aic, in_dims=1)(self.popt, self.F, self.Z, self.Zerr))
        return self.popt, self.perr, self.chi_sqr, self.aic

Jax

    def fit_stochastic(self, learning_rate, num_epochs):
        self.learning_rate = learning_rate
        self.par_log = self.convert_to_internal(self.p0)
        self.opt_init, self.opt_update, self.get_params = jax_opt.adam(learning_rate)
        self.opt_state = self.opt_init(self.par_log)  
        self.num_epochs = int(num_epochs)
        # Timing
        from datetime import datetime
        start = datetime.now()
        self.loss_history = []
        for epoch in range(self.num_epochs):
            self.loss, self.opt_state = jax.jit(self.train_step)(epoch, self.opt_state, self.F, self.Z, \
                                                            self.Zerr, self.lb_mat, self.ub_mat, self.smf)
            self.loss_history.append(float(self.loss))
            if epoch%int(self.num_epochs/10)==0:
                print("" + str(epoch) + ": "
                    + "loss=" + "{:5.3e}".format(self.loss)
                ) 
        self.popt = self.convert_to_external(self.get_params(self.opt_state))
        self.perr, self.corr = self.compute_perr(self.popt, self.F, self.Z, self.Zerr)
        self.chi_sqr = jnp.mean(jax.vmap(self.wrms_func, in_axes=1)(self.popt, self.F, self.Z, self.Zerr))
        self.aic = jnp.mean(jax.vmap(self.compute_aic, in_axes=1)(self.popt, self.F, self.Z, self.Zerr))
        end = datetime.now()
        print(f"total time is {end-start}", end=" ")
        return self.popt, self.perr, self.corr, self.chi_sqr, self.aic

You are appending the loss with the entire computation graph here:

self.losses.append(self.loss.clone())  

which would eventually yield the out of memory issue.
Assuming you want to use self.losses for logging purposes, detach() the loss before appending it.

1 Like