Getting Info on Connections Between Layers in Models with Split Connections

I know how to get information about each layer by iterating over the model's children:

for name, child in model.named_children():
    print(name, child, type(child))


I also know how to get the number of parameters and their values:

for name, param in model.named_parameters():
    print(name, param.numel(), param.detach())

In a sequential model, the connection sizes follow directly from the layer parameters: each layer's output dimension (e.g. out_features of one nn.Linear) equals the input dimension of the next layer (its in_features).
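To make the sequential case concrete, here is a minimal sketch: walking the nn.Linear children of an nn.Sequential in order and pairing neighbours gives each connection's size from out_features. The example stack and the helper sequential_connections are hypothetical, and the sketch assumes every layer between two Linear layers (here nn.ReLU) preserves the feature dimension.

```python
import torch.nn as nn

# Hypothetical example stack: every layer feeds only the next one.
seq = nn.Sequential(
    nn.Linear(1, 20),
    nn.ReLU(),  # shape-preserving, so it does not change the connection size
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)


def sequential_connections(model: nn.Sequential):
    """Return (from_name, to_name, size) for each pair of consecutive nn.Linear children.

    Assumes everything between two Linear layers preserves the feature dimension.
    """
    linears = [(name, m) for name, m in model.named_children() if isinstance(m, nn.Linear)]
    connections = []
    for (src_name, src), (dst_name, dst) in zip(linears, linears[1:]):
        # In a valid stack the producer's output width equals the consumer's input width.
        assert src.out_features == dst.in_features
        connections.append((src_name, dst_name, src.out_features))
    return connections


for src, dst, size in sequential_connections(seq):
    print(f"{src} -> {dst}: {size}")
```

For this stack it prints "0 -> 2: 20" and "2 -> 4: 10" (nn.Sequential names its children "0", "1", ...). This only works because the order of the children is the order of the data flow, which is exactly what breaks down for a split model.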

What I don't know is how to find the names and sizes of the layers that a given layer feeds into in a non-sequential model.

For instance, suppose we have the following dummy model with a split connection:

import torch
import torch.nn as nn

class Split_Model(nn.Module):
    def __init__(self):
        super(Split_Model, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2a = nn.Linear(10, 10)
        self.fc2b = nn.Linear(10, 10)
        self.fc3 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        y, x = x[:, :10], x[:, 10:20]
        x = self.fc2a(x)
        y = self.fc2b(y)
        x = torch.cat([x, y], dim=1)
        x = self.fc3(x)
        return x

model = Split_Model()
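
For reference, named_children() only lists submodules in registration order; the actual data flow lives in forward(). A minimal sketch, assuming the model can be traced with torch.fx.symbolic_trace (this one can, since forward() has no data-dependent if/loops), prints every node of the traced graph together with the nodes it consumes:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Split_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2a = nn.Linear(10, 10)
        self.fc2b = nn.Linear(10, 10)
        self.fc3 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        y, x = x[:, :10], x[:, 10:20]
        x = self.fc2a(x)
        y = self.fc2b(y)
        x = torch.cat([x, y], dim=1)
        return self.fc3(x)


model = Split_Model()
gm = symbolic_trace(model)  # records forward() as a graph of Node objects

# op is e.g. "placeholder", "call_module", "call_function" or "output";
# all_input_nodes lists the nodes whose results this node consumes.
for node in gm.graph.nodes:
    print(node.op, node.name, [n.name for n in node.all_input_nodes])
```

In this graph the two slices show up as separate getitem nodes sitting between fc1 and fc2a/fc2b, so finding "the next layer" means following the edges through those intermediate nodes.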

Split (branching) connections are becoming more common, since branches and skip connections interconnect a model more tightly and help mitigate vanishing gradients in very deep models. How can I find which layers in the graph fc1 is connected to? Ideally I would get something like:

fc2a: 10, fc2b: 10  # something like this is what I'd ideally like as output

How would I get this information out of a defined model with split connections?
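One possible approach is a minimal sketch under these assumptions: the model is traceable with torch.fx.symbolic_trace, and "next layer" means the first submodule reached on each path, skipping over non-module operations such as slicing or torch.cat. The helper next_layers is hypothetical, not a PyTorch API. It uses torch.fx.passes.shape_prop.ShapeProp to run one example input through the traced graph so that every node records its output shape, and reports the last dimension of the tensor each downstream module receives.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class Split_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2a = nn.Linear(10, 10)
        self.fc2b = nn.Linear(10, 10)
        self.fc3 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        y, x = x[:, :10], x[:, 10:20]
        x = self.fc2a(x)
        y = self.fc2b(y)
        x = torch.cat([x, y], dim=1)
        return self.fc3(x)


def next_layers(model, layer_name, example_input):
    """Hypothetical helper: map each submodule fed by `layer_name` (directly or
    through non-module ops) to the last-dimension size of the tensor it receives."""
    gm = symbolic_trace(model)
    ShapeProp(gm).propagate(example_input)  # fills node.meta["tensor_meta"]
    start = next((n for n in gm.graph.nodes
                  if n.op == "call_module" and n.target == layer_name), None)
    if start is None:
        raise ValueError(f"no call to submodule {layer_name!r} in the traced graph")

    result = {}
    frontier, seen = list(start.users), set()
    while frontier:
        node = frontier.pop(0)
        if node in seen:
            continue
        seen.add(node)
        if node.op == "call_module":
            # Stop at the first module on this path and record the width of its input.
            result[node.target] = node.args[0].meta["tensor_meta"].shape[-1]
        else:
            # Slices, torch.cat, functional ops, ...: keep following the data.
            frontier.extend(node.users)
    return result


model = Split_Model()
print(next_layers(model, "fc1", torch.randn(4, 1)))
```

For fc1 this yields 10 for both fc2a and fc2b; asking about fc2a instead follows the path through torch.cat and reports fc3 with width 20. Design caveats: only the first call of a reused submodule is matched, functional ops such as F.relu are treated as pass-through rather than as layers, and models with data-dependent control flow in forward() will not trace.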

Thank you!