Multiple inputs to model?

Hi! For the loss function I am trying to implement, I want my model to take some input image, use a CNN architecture to generate an image, and then process the generated image a bit to create my final network output. I want to calculate the loss and optimize based on this final network output.

The processing I’m doing is: subtracting the input from the generated image, discretizing the difference according to some bins (I’m currently thinking of using NumPy for this, since I couldn’t find a discretizing function in PyTorch), and adding the discretized values back to the input image. As far as I can tell, this should be differentiable, which I understand is needed to backpropagate through the loss. Is that right?
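
To make the steps concrete, here is a minimal sketch of that processing in PyTorch rather than NumPy. It assumes the bins are a sorted 1-D tensor and that each difference is snapped down to the nearest bin edge. The names `discretize_residual`, `bin_edges`, and `straight_through` are hypothetical. `torch.bucketize` returns integer indices, so the plain version carries no gradient back to the generated image. The optional straight-through variant is one common workaround, not something I have settled on.

```python
import torch

def discretize_residual(image, generated, bin_edges, straight_through=False):
    """Snap (generated - image) down to a bin edge and add it back to image.

    bin_edges: sorted 1-D tensor of bin values (hypothetical layout).
    """
    residual = generated - image
    # bucketize returns integer bin indices, so no gradient flows through it.
    idx = torch.bucketize(residual.detach(), bin_edges, right=True) - 1
    idx = idx.clamp(0, bin_edges.numel() - 1)  # out-of-range values go to the end bins
    quantized = bin_edges[idx]
    if straight_through:
        # Forward value is image + quantized; backward treats the rounding as identity.
        return image + residual + (quantized - residual).detach()
    return image + quantized

bin_edges = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
image = torch.zeros(4)
generated = torch.tensor([0.3, -0.7, 2.0, -2.0], requires_grad=True)  # stands in for CNN output

hard = discretize_residual(image, generated, bin_edges)
soft = discretize_residual(image, generated, bin_edges, straight_through=True)
print(hard.requires_grad, soft.requires_grad)
```

In this sketch the plain result is detached from the graph, so calling `backward()` on a loss built from it would not reach the CNN weights. The straight-through result has the same values but stays connected.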

A couple of questions came up as I thought about this:

  • Where should I incorporate the post-processing? In the model initialization/forward call, or afterwards?
  • If I do it in the model part, I’d need to pass the bin values to the model. Is there a recommended way of doing this? I was thinking I could bundle the input image and the bin values into a tuple and use that as the input to my forward call?
  • If I do it after the model forward has been called (e.g., after output = model(input)), what is the best practice so that loss.backward() and optimizer.step() still work properly?
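
The two placements I have in mind can be sketched roughly as follows. This is only an illustration under assumptions: `post_process` is a differentiable placeholder (it limits the difference to the bin range instead of discretizing it), a single `nn.Conv2d` stands in for the CNN, and `GeneratorWithPostProcess` is a hypothetical wrapper name. It shows `forward` taking the bins as a second argument instead of a tuple, next to the same processing applied after the forward call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def post_process(image, generated, bin_edges):
    """Differentiable placeholder for the real post-processing (an assumption):
    limits the residual to the range covered by the bins instead of discretizing it."""
    residual = generated - image
    lo, hi = bin_edges[0].item(), bin_edges[-1].item()
    return image + residual.clamp(lo, hi)

class GeneratorWithPostProcess(nn.Module):
    """Option A: post-processing inside forward, bins passed as a second argument."""
    def __init__(self, generator):
        super().__init__()
        self.generator = generator

    def forward(self, image, bin_edges):
        return post_process(image, self.generator(image), bin_edges)

torch.manual_seed(0)
generator = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for the CNN
bin_edges = torch.tensor([-1.0, 0.0, 1.0])
image = torch.rand(2, 1, 8, 8)
target = torch.rand(2, 1, 8, 8)

# Option A: call the wrapper with two positional inputs.
model = GeneratorWithPostProcess(generator)
out_inside = model(image, bin_edges)

# Option B: post-process after the plain forward call; autograd still tracks it.
out_outside = post_process(image, generator(image), bin_edges)

optimizer = torch.optim.SGD(generator.parameters(), lr=0.1)
loss = F.mse_loss(out_outside, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Both options produce the same output here, since autograd records tensor operations wherever they run. This only holds as long as the post-processing stays in PyTorch tensor operations and does not go through NumPy.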

I know it’s a complicated, specific scenario. Thanks in advance!