Implementing half-way branch in the network

Question

I would like to implement a network architecture similar to Inception, shown below. The two are similar in that both have multiple half-way branches (the part within the blue block).

To implement this architecture, I could write the HalfWayBranch modules inside MainBranch. However, I would like to keep the different modules as decoupled as possible.

Then I have

import torch.nn as nn

class MainBranch(nn.Module):
    ...  # body omitted

class HalfWayBranch(nn.Module):
    ...  # body omitted

Now I have two questions:

  • How do I write my training loop?
  • Will the loss be backpropagated correctly?

If I understand your use case correctly, you would like to add these auxiliary modules to your model.
I would recommend writing a main module and initializing the base model as well as these aux modules inside its __init__.
Then, in forward, you can take an intermediate output, feed it both to the next module of your base model and to one of the aux modules, and return all necessary outputs at the end. This would be similar to the Inception implementation.
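
As a rough illustration of the suggestion above, here is a minimal sketch with one auxiliary branch and a single training step. The wrapper name Net, the layer sizes, the stage1/stage2/classifier attributes, the dummy data, and the 0.3 weight on the auxiliary loss are all assumptions made for the example, not details from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MainBranch(nn.Module):
    """Base model, split into stages so intermediate outputs are reachable."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Hypothetical layers; any stack of modules would work here.
        self.stage1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.stage2 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, num_classes)
        )


class HalfWayBranch(nn.Module):
    """Auxiliary head that classifies from an intermediate feature map."""

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        return self.fc(torch.flatten(self.pool(x), 1))


class Net(nn.Module):
    """Hypothetical wrapper that owns the base model and the aux module."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.base = MainBranch(num_classes)
        self.aux = HalfWayBranch(8, num_classes)

    def forward(self, x):
        feat = self.base.stage1(x)       # intermediate output
        aux_out = self.aux(feat)         # half-way branch
        main_out = self.base.classifier(self.base.stage2(feat))
        return main_out, aux_out


torch.manual_seed(0)
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Dummy batch standing in for a real DataLoader.
x = torch.randn(4, 3, 16, 16)
y = torch.randint(0, 10, (4,))

# One training step: combine both losses, then call backward() once.
main_out, aux_out = model(x)
loss = F.cross_entropy(main_out, y) + 0.3 * F.cross_entropy(aux_out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because both outputs come from the same forward pass and the losses are summed before backward(), autograd sends gradients through the aux module and the main path, and the shared stage1 parameters accumulate the contributions from both. A real training loop would repeat the last block for every batch.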
