How to Jointly Train Models that Split? [Advanced Computer Vision Paper]

I’m interested in implementing the Aux loss model described in this paper. I have attached an image of the model architecture below.

A small description of the model and how it works:
This model classifies X-ray images. PA and L are two different views of an X-ray exam for one patient: PA is the view taken from the back (posteroanterior), and L is the same exam taken from the side (lateral).

The paper leverages DenseNet121 pretrained models (4 dense blocks) for this architecture. In the image below, a circle represents a dense block.

Here’s the logic/control flow of the model:

  • If the L image is present but the PA image is missing for a patient, the data should only flow through the L branch (i.e. the DenseNet121 model for L).
  • If the PA image is present but the L image is missing for a patient, the data should only flow through the PA branch (i.e. the DenseNet121 model for PA).
  • If both images are present, the data should flow through both branches up to the 3rd dense block. After that, the outputs of the two 3rd dense blocks should be concatenated and passed through another model (assume it’s linear).

I understand how I could implement and train this model jointly for a batch size of 1. However, the paper claims that they can train this model jointly for a batch size of 8. How could I implement this efficiently in PyTorch for a batch size of 8?
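
For reference, here is a minimal sketch of how I currently picture the module structure, assuming torchvision’s DenseNet121 (where .features is a Sequential that I can slice right after the third dense block). The split index, channel counts, and all the names below are my own guesses, not taken from the paper:

import torch
import torch.nn as nn
from torchvision import models

class AuxLossModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        pa = models.densenet121(pretrained=True)
        lat = models.densenet121(pretrained=True)
        # features[:9] = stem + dense blocks 1-3, features[9:] = transition 3 + dense block 4 + final norm
        # (the index 9 is my guess based on torchvision's layout)
        self.pa_front, self.pa_rest = pa.features[:9], pa.features[9:]
        self.l_front, self.l_rest = lat.features[:9], lat.features[9:]
        self.pool = nn.AdaptiveAvgPool2d(1)
        # single-view heads (DenseNet121 ends with 1024 feature channels)
        self.pa_head = nn.Linear(1024, num_classes)
        self.l_head = nn.Linear(1024, num_classes)
        # joint head on the concatenated 3rd-block features (1024 channels per branch -> 2048)
        self.joint_head = nn.Linear(2048, num_classes)

The forward pass (the branching/routing part) is what I’m not sure how to write once the batch is bigger than 1.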

What do you mean by efficiently? It’s a matter of putting if statements in the forward pass. You “may” be able to optimize it by using two optimizers so that you only update the branch that was actually used.

Here’s some pseudo code for the forward pass I have in my model so far:

def forward(self, x):
    if x has PA view but not L view:
        output = run x through the PA classifier
    elif x has L view but not PA view:
        output = run x through the L classifier
    else:
        output1 = run x through the first 3 blocks of the PA classifier
        output2 = run x through the first 3 blocks of the L classifier
        output = concatenate output1 and output2 and feed them through the Linear layer

    return output

Is the forward pass designed to assume a batch size of 1?

If not, wouldn’t I need to do this:

def forward(self, x):
    for example in x:
        if example has PA view but not L view:
            .....

    return output

By default, PyTorch nn.Modules expect the first dimension to be the batch dimension.
For an RGB image, a single sample has size 3xHxW (channels x height x width). The input to the model gets an additional batch dimension, so it becomes Nx3xHxW.
In general, any model is designed this way; increasing your batch size is nothing but making N=8.
This stacking step is done automatically by the dataloader. If you are using the PyTorch DataLoader, you just need to change the batch_size argument.
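
A minimal sketch with a dummy dataset (the shapes and names here are just for illustration):

import torch
from torch.utils.data import Dataset, DataLoader

class DummyXrayDataset(Dataset):
    # Toy dataset: each sample is a fake 3xHxW image with a binary label
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        image = torch.randn(3, 224, 224)   # one sample: 3xHxW
        label = torch.tensor(idx % 2)
        return image, label

loader = DataLoader(DummyXrayDataset(), batch_size=8)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([8, 3, 224, 224]) -> Nx3xHxW with N = 8
print(labels.shape)  # torch.Size([8])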

WRT optimizers: an optimizer iterates over whatever model parameters you passed to it.
If you use a single optimizer it will iterate over everything, see that the gradients of the unused branch are None, and skip them. But it still spends time iterating over all of the parameters.

Using several optimizers means each one only iterates over the parameters you know actually contain gradients. I don’t really know how much time you can save by doing this, and you would need to extend the branching logic beyond the forward pass.
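
A tiny example to illustrate the single-optimizer case (a made-up two-branch module, not your model): only one branch is used in the forward, so the other branch’s gradients stay None and the optimizer just skips those parameters.

import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 2)
        self.branch_b = nn.Linear(4, 2)

    def forward(self, x, use_a=True):
        return self.branch_a(x) if use_a else self.branch_b(x)

model = TwoBranch()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = model(torch.randn(8, 4), use_a=True).sum()
loss.backward()

print(model.branch_a.weight.grad is None)  # False: this branch was used
print(model.branch_b.weight.grad is None)  # True: never used, the optimizer skips it
optimizer.step()                           # only branch_a's parameters are updated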

So the normal way would be:

optimizer = OPTIM(model.parameters())
optimizer.zero_grad()
output = model(input)
loss = criterion(output, gt)
loss.backward()
optimizer.step()

It would become something like:

optimizer1 = OPTIM(branch1.parameters())  # PA branch
optimizer2 = OPTIM(branch2.parameters())  # L branch
output = model(input)
loss = criterion(output, gt)
if input has PA view but not L view:
    optimizer1.zero_grad()
    loss.backward()
    optimizer1.step()
elif input has L view but not PA view:
    optimizer2.zero_grad()
    loss.backward()
    optimizer2.step()

But I don’t know how much time you would save.
If you find it confusing, just don’t do it.
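
If it helps, here is a rough sketch of what the “if statements in the forward pass” could look like for a mixed batch, reusing the module structure you sketched above. It assumes your dataset/collate function also returns boolean masks saying which views are present for each sample, with zero-filled tensors where a view is missing; all the names (has_pa, has_l, pa_front, joint_head, …) are illustrative, not from the paper:

def forward(self, pa, l, has_pa, has_l):
    # pa, l:          [N, 3, H, W] image batches (zero-filled where a view is missing)
    # has_pa, has_l:  [N] boolean masks saying which views exist for each sample
    N = pa.shape[0]
    logits = pa.new_zeros(N, self.num_classes)

    only_pa = has_pa & ~has_l
    only_l = has_l & ~has_pa
    both = has_pa & has_l

    if only_pa.any():
        # PA only: run the full PA DenseNet and its own head
        feats = self.pool(self.pa_rest(self.pa_front(pa[only_pa]))).flatten(1)
        logits[only_pa] = self.pa_head(feats)
    if only_l.any():
        # L only: run the full L DenseNet and its own head
        feats = self.pool(self.l_rest(self.l_front(l[only_l]))).flatten(1)
        logits[only_l] = self.l_head(feats)
    if both.any():
        # Both views: stop after the 3rd dense block of each branch,
        # concatenate, then go through the joint linear head
        f_pa = self.pool(self.pa_front(pa[both])).flatten(1)
        f_l = self.pool(self.l_front(l[both])).flatten(1)
        logits[both] = self.joint_head(torch.cat([f_pa, f_l], dim=1))

    return logits

Each sub-batch goes through its branch in one call, so you keep the batched efficiency, and a single optimizer over model.parameters() works fine with this: any branch that wasn’t touched in a given batch just ends up with None gradients and gets skipped.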