I have these two functions, one written in PyTorch and the other written in JAX. The PyTorch one kills my kernel on my Linux laptop (16 GB RAM) and also kills the kernel (out of memory) on my Colab Pro+ as soon as I set num_epochs above 1e5. This happens even when I move the device to GPU. Meanwhile, the JAX one runs without any issues even when I set num_epochs to 1e6. I would like to know why this happens, because I have a program I really want to write using PyTorch. Thank you.
PyTorch
def fit_stochastic(self, learning_rate: float, num_epochs: int):
    """Fit the model parameters with the Adam optimizer.

    Parameters
    ----------
    learning_rate : float
        Adam step size.
    num_epochs : int
        Number of optimization steps (coerced to int, so 1e5 etc. work).

    Returns
    -------
    tuple
        (popt, perr, chi_sqr, aic) — external parameters, their errors,
        mean weighted residual chi-square, and mean AIC.
    """
    self.learning_rate = learning_rate
    self.num_epochs = int(num_epochs)
    # Internal (log-space, presumably) parameters; must be a leaf tensor for Adam.
    self.par_log = self.convert_to_internal(self.p0).requires_grad_(True)
    optimizer = torch.optim.Adam(params=[self.par_log], lr=learning_rate)
    self.losses = []
    # Print ~10 progress lines; max(1, ...) avoids ZeroDivisionError when num_epochs < 10.
    print_every = max(1, self.num_epochs // 10)
    for epoch in range(self.num_epochs):
        optimizer.zero_grad()
        self.loss = self.cost_func(self.par_log, self.F, self.Z, self.Zerr,
                                   self.lb_mat, self.ub_mat, self.smf)
        if epoch % print_every == 0:
            print(str(epoch) + ": " + "loss=" + "{:5.3e}".format(self.loss.item()))
        # BUG FIX (this was the OOM): .clone() keeps the loss tensor attached to
        # its autograd graph, so every epoch's ENTIRE computation graph stayed
        # alive inside self.losses and memory grew without bound with num_epochs.
        # .item() stores a plain Python float — exactly what the JAX version
        # does with float(self.loss).
        self.losses.append(self.loss.item())
        self.loss.backward()
        optimizer.step()
    self.popt = self.convert_to_external(self.par_log)
    self.perr = self.compute_perr(self.popt, self.F, self.Z, self.Zerr)
    # torch.vmap replaces the deprecated functorch.vmap (same semantics, torch>=2.0).
    self.chi_sqr = torch.mean(torch.vmap(self.wrms_func, in_dims=1)(self.popt, self.F, self.Z, self.Zerr))
    self.aic = torch.mean(torch.vmap(self.compute_aic, in_dims=1)(self.popt, self.F, self.Z, self.Zerr))
    return self.popt, self.perr, self.chi_sqr, self.aic
JAX
def fit_stochastic(self, learning_rate, num_epochs):
self.learning_rate = learning_rate
self.par_log = self.convert_to_internal(self.p0)
self.opt_init, self.opt_update, self.get_params = jax_opt.adam(learning_rate)
self.opt_state = self.opt_init(self.par_log)
self.num_epochs = int(num_epochs)
# Timing
from datetime import datetime
start = datetime.now()
self.loss_history = []
for epoch in range(self.num_epochs):
self.loss, self.opt_state = jax.jit(self.train_step)(epoch, self.opt_state, self.F, self.Z, \
self.Zerr, self.lb_mat, self.ub_mat, self.smf)
self.loss_history.append(float(self.loss))
if epoch%int(self.num_epochs/10)==0:
print("" + str(epoch) + ": "
+ "loss=" + "{:5.3e}".format(self.loss)
)
self.popt = self.convert_to_external(self.get_params(self.opt_state))
self.perr, self.corr = self.compute_perr(self.popt, self.F, self.Z, self.Zerr)
self.chi_sqr = jnp.mean(jax.vmap(self.wrms_func, in_axes=1)(self.popt, self.F, self.Z, self.Zerr))
self.aic = jnp.mean(jax.vmap(self.compute_aic, in_axes=1)(self.popt, self.F, self.Z, self.Zerr))
end = datetime.now()
print(f"total time is {end-start}", end=" ")
return self.popt, self.perr, self.corr, self.chi_sqr, self.aic