What is the best design pattern to have different training code and inference code in JIT?

For certain models, the forward path might be slightly different in training vs inference.

If I need to export the inference code (but not the training code), what is a good design pattern to achieve this?

@jit.script_method
def forward(self, inputs):
    if TRAINING:
        # training code that does not need to be compiled
        ...
    else:
        # inference code, compile this
        ...

Would it work to keep forward as a regular method that calls the script_method self.forward_inference, combined with tracing?

Best regards

Thomas
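A minimal sketch of the "regular forward plus tracing" idea from the question, assuming a PyTorch release with torch.jit.trace and plain nn.Module (rather than the jit.ScriptModule API used in this thread). The layer sizes and the noise in the training path are made up for illustration. The point is that the tracer only records what actually runs, so tracing in eval mode captures the inference path alone:

```python
import torch
import torch.nn as nn


class ModuleB(nn.Module):
    """Hypothetical module with different training and inference paths."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward_training(self, inputs):
        # Plain Python; not executed (and so not recorded) in eval mode.
        noisy = inputs + 0.1 * torch.randn_like(inputs)
        return self.linear(noisy)

    def forward_inference(self, inputs):
        return self.linear(inputs)

    def forward(self, inputs):
        # Ordinary Python branch on the standard nn.Module flag.
        if self.training:
            return self.forward_training(inputs)
        return self.forward_inference(inputs)


model = ModuleB().eval()
example = torch.randn(3, 4)
# The module is in eval mode, so only the inference path is recorded.
traced = torch.jit.trace(model, example)
```

One caveat with this approach: tracing bakes in whichever branch was taken, so any data-dependent control flow inside the inference path would not be preserved.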

But eventually something needs to have the if statement, right?

Even if I keep this branching code in only one module, the caller eventually also needs an if statement to know whether it is in training mode.

For example, either I write this:

class ModuleA(jit.ScriptModule):
    @jit.script_method
    def forward(self, inputs):
        return self.module_b(inputs)

class ModuleB(jit.ScriptModule):
    @jit.script_method
    def forward(self, inputs):
        if TRAINING:
            ...  # training path
        else:
            ...  # inference path

or I write this:

class ModuleA(jit.ScriptModule):
    @jit.script_method
    def forward(self, inputs):
        if TRAINING:
            return self.module_b.forward_training(inputs)
        else:
            return self.module_b.forward_inference(inputs)


class ModuleB(jit.ScriptModule):
    @jit.script_method
    def forward_inference(self, inputs):
        ...

    def forward_training(self, inputs):
        ...

I don’t want every single module in my model to need forward_training and forward_inference just because one module has to branch between training and inference.

Yes, but if you use self.training, that should be OK, no?

What do you mean by using self.training? My problem is that forward_training might contain raw Python code that is not exportable, but introducing a branch makes the script_method not exportable.
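One possible way around this, sketched under the assumption of a PyTorch release that provides torch.jit.script for nn.Module and the torch.jit.unused decorator (neither is part of the jit.ScriptModule / script_method API used in this thread). The decorator tells the compiler to skip the method body and replace calls to it with a raised exception, so the branch on self.training can stay in forward while the module remains exportable. The layer sizes and the random scaling are hypothetical:

```python
import io
import random

import torch
import torch.nn as nn


class ModuleB(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    @torch.jit.unused
    def forward_training(self, inputs: torch.Tensor) -> torch.Tensor:
        # Body is not compiled, so arbitrary Python is allowed here.
        scale = random.uniform(0.9, 1.1)
        return self.linear(inputs) * scale

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.training:
            return self.forward_training(inputs)
        return torch.relu(self.linear(inputs))


class ModuleA(nn.Module):
    """Not branchy: just calls module_b through a single forward."""

    def __init__(self):
        super().__init__()
        self.module_b = ModuleB()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.module_b(inputs)


model = ModuleA().eval()
scripted = torch.jit.script(model)

# Export to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
```

With this layout only ModuleB knows about the two modes. The training path still works in eager mode; in the scripted module it would raise if the module were switched to training mode and called.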

The question is:

Let’s say I have module A and module B. Module A uses module B.

Module A's code is not branchy (inference and training are the same). I want to be able to call module A with just forward. I don’t want module A to need a forward_training and a forward_inference.

Module B's code is branchy (inference and training are different). I am OK with it having a forward_training and a forward_inference.

Maybe I can do this in the initializer like this:

class ModuleB(jit.ScriptModule):

    def __init__(self):
        super().__init__()
        if TRAINING:
            self.forward = self.forward_training
        else:
            self.forward = self.forward_inference

    @jit.script_method
    def forward_inference(self, inputs):
        ...

    def forward_training(self, inputs):
        ...