Keras-like add_loss for torch.nn.Module

I develop Bayesian models using variational inference and am coming from a TensorFlow / TensorFlow Probability background. I've found the Keras loss-aggregation pattern, where layers call self.add_loss(kl_div), very useful in this domain: it leads to more modular code, since any model component can locally add a prior for an intermediate random variable. I'd like to implement a similar pattern for torch.nn.Module. Does anyone have advice on how best to tackle this?
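
For context, this is roughly what the pattern looks like on the Keras side (a toy layer of my own; the KL term here is just a placeholder for a real divergence):

```python
import tensorflow as tf

class KLRegularizedDense(tf.keras.layers.Layer):
    """Toy layer: any layer can register a loss term locally via add_loss."""

    def __init__(self, units):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        out = self.dense(inputs)
        kl_div = 1e-3 * tf.reduce_sum(tf.square(out))  # placeholder KL term
        self.add_loss(kl_div)
        return out
```

Keras then exposes the aggregated terms as model.losses, and compile()/fit() fold them into the training objective automatically.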

  • Should I use buffers?
  • Should I just set a private attribute on each module and then aggregate by traversing all submodules from the top-level module via Module.modules()? (A rough sketch of this option follows the list.)
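
For the second option, something like the following minimal sketch is what I have in mind (AddLossMixin, collect_losses, clear_all_losses, and BayesianLinear are names I made up, not part of torch):

```python
import torch
import torch.nn as nn


class AddLossMixin:
    """Lets any nn.Module register extra loss terms locally during forward."""

    def add_loss(self, loss: torch.Tensor) -> None:
        # Lazily create the list so the mixin needs no __init__ cooperation.
        if not hasattr(self, "_extra_losses"):
            self._extra_losses = []
        self._extra_losses.append(loss)


def collect_losses(model: nn.Module) -> torch.Tensor:
    """Sum the loss terms registered by all submodules in the last forward."""
    terms = []
    for module in model.modules():  # iterates over model itself and all children
        terms.extend(getattr(module, "_extra_losses", []))
    return torch.stack(terms).sum() if terms else torch.zeros(())


def clear_all_losses(model: nn.Module) -> None:
    """Reset registered terms; call this before each forward pass."""
    for module in model.modules():
        if hasattr(module, "_extra_losses"):
            module._extra_losses = []


class BayesianLinear(AddLossMixin, nn.Module):
    """Toy variational layer; the penalty stands in for a real KL divergence."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(0.01 * torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.add_loss(0.5 * self.weight.pow(2).sum())  # placeholder prior term
        return x @ self.weight.t()


model = nn.Sequential(BayesianLinear(10, 20), nn.ReLU(), BayesianLinear(20, 1))
x = torch.randn(4, 10)

clear_all_losses(model)  # avoid accumulating terms across training steps
out = model(x)
loss = out.pow(2).mean() + collect_losses(model)  # data term + local priors
loss.backward()
```

My hesitation about plain attributes is that they are invisible to state_dict and device moves, which is why I was also wondering about buffers; but since these terms are rebuilt on every forward and need to stay in the autograd graph, buffers (which are meant for persistent, non-gradient state) may be the wrong fit.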