I was facing a problem of getting negative KLs or very large KLs times to times during training of my model. So I use the cooper library to impose some constraints on the KL terms of the components of mixture of experts model. Here is my constraint class
import cooper
class PenalizedKLConstraint(cooper.ConstrainedMinimizationProblem):
"""
Class for KL-penalized constrains.
"""
def __init__(self, module: nn.Module, l1_lambda: float= 0.0001, mean_constraint:float = 2.5, device: torch.types.Device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
self.mean_constraint = mean_constraint
self._module = module
self.l1_lambda = l1_lambda
self.device = device
self._module.to(self.device)
super(PenalizedKLConstraint, self).__init__(is_constrained=True)
def compute_losses(
self,
model: nn.Module,
context: torch.Tensor,
another_context:torch.Tensor,
precision: torch.Tensor,
weights: torch.Tensor,
):
model.to(self.device)
with torch.no_grad():
samples = model.sample(context.to(self.device))
losses = - torch.squeeze(self._module(torch.cat([context.to(self.device), samples], dim=-1), inTrain=False))
kls = model.kls_other_chol_inv(context.to(self.device), another_context.to(self.device), precision.to(self.device))
#Adding L1 regularization to prevent gradient explosion
l1_norm = torch.mul(self.l1_lambda , sum(p.abs().sum() for p in model.parameters())).to(self.device)
loss = torch.mean(weights * (losses + kls)).to(self.device) + l1_norm + model._regularizer
return loss, kls, l1_norm
def closure(
self,
inputs: torch.Tensor,
targets: torch.Tensor,
p:torch.Tensor,
w:torch.Tensor,
model: nn.Module,
) -> cooper.CMPState:
"""
Computes CMP state for the given model and inputs.
"""
# Compute loss and regularization statistics
loss, kls, l1_norm = self.compute_losses(model, inputs, targets, p, w)
misc = {"l1_norm_vlue": l1_norm, "kl_value": kls}
# Entries of kls >= 0 (equiv. -kls <= 0) & kls - costraint <= 0
ineq_defect = torch.stack([-kls, kls - self.mean_constraint])
# Store model output and other 'miscellaneous' objects in misc dict
state = cooper.CMPState(loss=loss, ineq_defect=ineq_defect , misc=misc)
return state
Then a snippet of my main class which is very long, here I post where I initialized the model with constraint
class MixtureEIM:
@staticmethod
def get_default_config() -> ConfigDict:
c = ConfigDict(
mean_constraint=10,
checkpoint_dir=os. getcwd()+ "/EIM_out_imgs/"
)
c.finalize_adding()
return c
def __init__(self, config: ConfigDict, context_dim: int, sample_dim: int, seed: int = 0):
self.checkpoint_dir= self.c.checkpoint_dir
self._mean_constraint= self.c.mean_constraint
self._cmp = PenalizedKLConstraint(self._dre, l1_lambda = 0.0001, mean_constraint=self._mean_constraint, device=device)
self._formulation= cooper.LagrangianFormulation(self._cmp)
self._model.to(device)
def _components_train_step(self, importance_weights, old_means, old_chol_precisions, save_model=False):
for i in range(self._model.num_components):
dataset = torch.utils.data.TensorDataset(self._train_contexts, importance_weights[:, i], old_means, old_chol_precisions)
loader = torch.utils.data.DataLoader( dataset, shuffle = True, batch_size=self.c.components_batch_size)
for batch_idx, (context_batch, iw_batch, old_means_batch, old_chol_precisions_batch) in enumerate(loader):
iw_batch = iw_batch / torch.sum(iw_batch)
primal_optimizer = cooper.optim.ExtraSGD(self._model.components[i].trainable_variables, lr=self.c.components_learning_rate, momentum=0.7)
dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-3, momentum=0.7)
# Wrap the formulation and both optimizers inside a ConstrainedOptimizer
coop = cooper.ConstrainedOptimizer(formulation=self._formulation,
primal_optimizer=primal_optimizer,
dual_optimizer=dual_optimizer,
)
coop.zero_grad()
lagrangian = self._formulation.composite_objective(self._cmp.closure, context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i])
self._formulation.custom_backward(lagrangian)
coop.step(self._cmp.closure, context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i])
loss = self._cmp.closure(model, context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i]).loss
When I ran the above code on a remote machine using slurm script, I got this error message:
what(): Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Exception raised from unpack at /tmp/coulombc/pytorch_build_2021-11-09_14-57-01/avx2/python3.8/pytorch/torch/csrc/autograd/saved_variable.cpp:122 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x2ae6de115905 in /home/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0xd6 (0x2ae6de0f72a9 in /home/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: torch::autograd::SavedVariable::unpack(std::shared_ptr<torch::autograd::Node>) const + 0x13d2 (0x2ae6bb4e6bb2 in /home/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
/var/spool/slurmd/job32471479/slurm_script: line 23: 84798 Aborted (core dumped) python EIMn.py
I am not sure where the computational graph got repeated which has caused me to get this error?
I will appreciate if someone can help. Thanks.
P.S. I can not use retain_graph=True for the backward inside the PenalizedKLConstraint since it is not defined.