Module design: optimizer and loss inside module

We’re working on a fairly big model that has a lot of “plug-in” components (which I’ll call submodules), i.e. smaller models that are trained on intermediate results of the big model during the forward pass. Each of these submodules is trained at the same time as the big model, each with its own optimizer and loss function. Submodules can be individually enabled or disabled.

The problem we have is that the individual optimizers and loss functions are starting to clutter the main training loop. For each submodule there needs to be a check for whether that module is enabled, a call to the backward pass, a call to the optimizer’s step function, and a call to zero_grad. Only the forward pass happens in a single call, since calling forward on the big model “cascades” into calling the forward of each submodule at some point.

Furthermore, to compute the loss for each submodule, the ground truth tensors and the predicted tensors of the submodules must make their way to the loss function somehow. As we only have a single loss function that computes the loss for every submodule, all of these tensors end up in the main training loop so they can be passed to that loss function. This is clearly not a scalable design.
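To make it concrete, here is roughly what the training loop looks like at the moment. This is just a toy sketch: the tiny Linear layers, the detached intermediate features and all the names are placeholders, not our actual code.

```python
import torch
import torch.nn as nn

trunk = nn.Linear(8, 4)                  # stands in for the big model
sub_a = nn.Linear(4, 2)                  # one "plug-in" submodule
sub_a_enabled = True

criterion = nn.MSELoss()                 # the single shared loss function
main_opt = torch.optim.Adam(trunk.parameters())
opt_a = torch.optim.Adam(sub_a.parameters())

for _ in range(3):                       # stands in for the data loader
    x = torch.randn(16, 8)
    y_main, y_a = torch.randn(16, 4), torch.randn(16, 2)

    # Single forward call; in the real model this cascades into the submodules.
    features = trunk(x)
    pred_a = sub_a(features.detach())    # submodule trained on an intermediate

    main_loss = criterion(features, y_main)
    main_loss.backward()
    main_opt.step()
    main_opt.zero_grad()

    # One block like this per submodule, and every ground-truth/prediction
    # tensor has to be available here in the training loop:
    if sub_a_enabled:
        loss_a = criterion(pred_a, y_a)
        loss_a.backward()
        opt_a.step()
        opt_a.zero_grad()
```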

Does PyTorch have a built-in mechanism to deal with this situation? I was thinking of refactoring the code so that each of these submodules owns its own optimizer and loss function. To avoid having to cache the result of each submodule’s forward pass, I was also thinking of giving them a step function that performs the forward, loss, and backward operations in one go, and an optim_step function that calls the optimizer’s step and then zero_grad. When step is called on the big model, it is also called on every submodule in a cascading fashion, and the same goes for optim_step.
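Roughly, what I have in mind looks like the sketch below. To be clear, none of this is a built-in torch mechanism; step, optim_step and the detaching of the intermediate features are just my own assumptions for illustration.

```python
import torch
import torch.nn as nn

class SubModule(nn.Module):
    """A plug-in model that owns its own loss function and optimizer."""

    def __init__(self, in_features, out_features, enabled=True):
        super().__init__()
        self.net = nn.Linear(in_features, out_features)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.enabled = enabled

    def step(self, features, target):
        # Forward, loss and backward in one go, so nothing has to be cached
        # in the outer training loop.
        if not self.enabled:
            return None
        loss = self.criterion(self.net(features.detach()), target)
        loss.backward()
        return loss

    def optim_step(self):
        if self.enabled:
            self.optimizer.step()
            self.optimizer.zero_grad()

class BigModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 4)
        self.subs = nn.ModuleList([SubModule(4, 2), SubModule(4, 3)])
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.trunk.parameters())

    def step(self, x, target, sub_targets):
        features = self.trunk(x)
        loss = self.criterion(features, target)
        loss.backward()
        # Cascade into every submodule.
        for sub, sub_target in zip(self.subs, sub_targets):
            sub.step(features, sub_target)
        return loss

    def optim_step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()
        for sub in self.subs:
            sub.optim_step()

# The training loop shrinks to two calls per batch:
model = BigModel()
x, y = torch.randn(16, 8), torch.randn(16, 4)
sub_ys = [torch.randn(16, 2), torch.randn(16, 3)]
model.step(x, y, sub_ys)
model.optim_step()
```

One open question with this design is checkpointing: since the optimizers live on the modules as plain attributes, model.state_dict() won’t include their state, so they would presumably have to be saved separately.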

Your idea generally makes sense.
One thing I would worry about is whether you are planning on using (Distributed)DataParallel at some point.
I’ve seen a lot of errors caused by just wrapping a model in DDP when it has a lot of these “convenience functions”.
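To illustrate the typical failure mode (sketching this from memory, so the exact behaviour may differ between versions): both wrappers only intercept forward(), so a custom method isn’t visible on the wrapper at all, and reaching it through .module bypasses the wrapper, i.e. DataParallel’s scatter/replicate and the gradient synchronization DDP sets up around its own forward.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 4)

    def forward(self, x):
        return self.net(x)

    def step(self, x, y):  # a custom "convenience function"
        loss = nn.functional.mse_loss(self(x), y)
        loss.backward()
        return loss

device = "cuda" if torch.cuda.is_available() else "cpu"
dp = nn.DataParallel(ToyModel().to(device))
x = torch.randn(16, 8, device=device)
y = torch.randn(16, 4, device=device)

dp(x)                 # fine: this goes through the wrapper's forward()

try:
    dp.step(x, y)     # the wrapper does not expose custom methods
except AttributeError as err:
    print(err)        # 'DataParallel' object has no attribute 'step'

dp.module.step(x, y)  # runs, but on the bare module: DataParallel's scatter
                      # (and, with DDP, the gradient all-reduce) is skipped
```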

That is indeed precisely what I’m dealing with at the moment: the main module is wrapped in DP and I’ve had to disable that for now. I hope I’ll find a nice way to work around it.