How to define arbitrary connections in the forward function when defining a model?

Hi everyone. As we know, when we define a model class, we define the layers in `__init__`, and in `forward` we specify how the layers are connected and compute the output. Example:

```python
import torch


class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        # dim=1: normalize over the 10 output features of each sample
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x
```

However, these connections are usually fixed, i.e. known beforehand. Suppose instead that I have a dictionary called connection_dict that stores the connections between layers (e.g. 1->2, 1->3, 2->4, etc.) and a list called layer_encoding_list that stores the layer definitions, and I want to build a model from these two variables. The skeleton would look like this:

```python
import torch
import torch.nn as nn


class MyDLModel(torch.nn.Module):

    def __init__(self, connection_dict, layer_encoding_list):
        super(MyDLModel, self).__init__()

        self.connection_dict = connection_dict
        self.layer_encoding_list = layer_encoding_list

        self.layer_list = nn.ModuleList()
        # Define the layers according to layer_encoding_list
        ...

    def forward(self, x):
        # Define the connections according to self.connection_dict
        ...
```

The layer definitions can be handled with nn.ModuleList(), but I'm not sure how to write the forward function so that backpropagation works correctly. Should the intermediate outputs go into an nn.ModuleList()? Or an nn.ModuleList() of nn.ModuleList()s? I have also seen examples where people append the layer outputs to a list and then call torch.cat on it, and it still works, but I'm not sure about the details, since here the connections can be arbitrary. Any help would be much appreciated.

As long as you register modules and parameters properly, Autograd won't have a problem. For example, use nn.ModuleList to hold a list of modules instead of a plain Python list.
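To illustrate why registration matters, here is a minimal sketch comparing the two containers (the class names are hypothetical). Layers kept in a plain Python list are not registered as submodules, so they are invisible to model.parameters(): an optimizer built from model.parameters() would never update them, and model.to(device) would not move them.

```python
import torch.nn as nn


class PlainListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does NOT register these layers
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 4)]


class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


plain_count = len(list(PlainListModel().parameters()))
registered_count = len(list(ModuleListModel().parameters()))
print(plain_count)       # 0
print(registered_count)  # 4 (a weight and a bias per nn.Linear)
```

Intermediate tensors are a different matter: they are not parameters, so keeping them in a plain list or dict is fine.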

Storing the layer outputs in a Python list and calling torch.cat on it is a valid approach and keeps the gradient history. To check whether a specific operation detached a tensor, inspect the .grad_fn attribute of the intermediate tensors before and after that operation. If both show a valid function (i.e. not None), the tensor was not detached.
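Putting the pieces together, here is a minimal sketch of how forward could follow an arbitrary connection_dict. The input formats are assumptions, since the question does not specify them: connection_dict maps a 1-based layer id to the list of layer ids it feeds, layer_encoding_list is a hypothetical list of (in_features, out_features) pairs for nn.Linear layers, a layer with several inputs concatenates them with torch.cat along dim=1, and the highest-numbered layer is the output. graphlib.TopologicalSorter (standard library, Python 3.9+) orders the layers so that each one runs after all of its inputs.

```python
import torch
import torch.nn as nn
from graphlib import TopologicalSorter  # Python 3.9+


class MyDLModel(nn.Module):
    """Minimal DAG model sketch. Assumed (hypothetical) input formats:

    connection_dict:     {src_id: [dst_id, ...]} with 1-based layer ids
    layer_encoding_list: [(in_features, out_features), ...]; entry k is layer k + 1
    Layers with no incoming edge read the model input; the layer with the
    highest id is treated as the output layer.
    """

    def __init__(self, connection_dict, layer_encoding_list):
        super().__init__()
        # Registered via nn.ModuleList, so model.parameters() sees every layer
        self.layer_list = nn.ModuleList(
            [nn.Linear(in_f, out_f) for in_f, out_f in layer_encoding_list]
        )
        num_layers = len(layer_encoding_list)
        # Invert the edges: for each layer, which layers feed into it
        self.inputs_of = {i: [] for i in range(1, num_layers + 1)}
        for src in sorted(connection_dict):
            for dst in connection_dict[src]:
                self.inputs_of[dst].append(src)
        # Topological order: every layer runs after all of its inputs
        sorter = TopologicalSorter()
        for layer_id, srcs in self.inputs_of.items():
            sorter.add(layer_id, *srcs)
        self.order = list(sorter.static_order())

    def forward(self, x):
        outputs = {}  # plain dict of tensors; autograd still tracks every entry
        last_id = len(self.layer_list)
        for layer_id in self.order:
            srcs = self.inputs_of[layer_id]
            if srcs:
                # Merge all incoming activations along the feature dimension
                h = torch.cat([outputs[s] for s in srcs], dim=1)
            else:
                h = x
            h = self.layer_list[layer_id - 1](h)
            if layer_id != last_id:
                h = torch.relu(h)
            outputs[layer_id] = h
        return outputs[last_id]


connection_dict = {1: [2, 3], 2: [4], 3: [4]}
layer_encoding_list = [(100, 64), (64, 32), (64, 32), (64, 10)]
model = MyDLModel(connection_dict, layer_encoding_list)

out = model(torch.randn(8, 100))
print(out.shape)                # torch.Size([8, 10])
print(out.grad_fn is not None)  # True: the output is still attached to the graph

out.sum().backward()
# Every registered parameter received a gradient
print(all(p.grad is not None for p in model.parameters()))  # True
```

Only the layers need registering; the activations live in an ordinary dict. Layer 4 here receives the outputs of layers 2 and 3, so its in_features must be 32 + 32 = 64. Summing the incoming tensors instead of concatenating them would be another reasonable merge rule when the shapes match, but it changes the in_features each layer expects.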