Sparsely connected graph that skips layers

I am trying to implement a graph using PyTorch where each layer contains a single node and connections between layers may skip several layers. As a simple example, the code below attempts to connect 4 nodes (A, B, C and D) with the connections A-B, A-C, B-D and C-D. Even if it worked, it is not clear to me how I could access the layers to know which one is connected to which:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.A = nn.Linear(1,1)
        self.B = nn.Linear(1,1)
        self.C = nn.Linear(1,1)
        self.D = nn.Linear(1,1)

    def forward(self, x):
        out = self.B(A,D)
        out = self.C(A,D)
        out = self.

if __name__ == '__main__':
    my_nn = Net()
    print("I'm a net!")

I don’t completely understand the use case, and the current code snippet seems to pass modules to the forward method of the linear layers.
Could you explain your use case a bit more and what these connections should look like?
If you want to pass the output of one layer to specific others (to create the “connections”), you should be able to do so:

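def forward(self, x):
    x = self.A(x)
    # A feeds both B and C
    xb = self.B(x)
    xc = self.C(x)
    # B and C both feed D; summing them (one option) keeps D's input size at 1
    x = self.D(xb + xc)
    return x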
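A hand-written forward like this works, but it does not record the connections anywhere you can query, which was the second part of the question. One option is to store the connectivity as data and let forward walk it. The sketch below is a minimal illustration under stated assumptions: the class `SparseGraphNet`, the `inputs_of` dict and the `successors` helper are hypothetical names, the graph is assumed to be acyclic, every node is an `nn.Linear(1, 1)` as in the question, and a node with several inputs simply sums them.

```python
import torch
import torch.nn as nn


class SparseGraphNet(nn.Module):
    """Hypothetical graph network: one nn.Linear(1, 1) per node, with the
    connectivity stored explicitly as {node: [nodes feeding it]}."""

    def __init__(self, inputs_of):
        super().__init__()
        # Nodes with an empty input list receive the network input x.
        self.inputs_of = inputs_of
        # ModuleDict registers every node's parameters with the module.
        self.nodes = nn.ModuleDict({name: nn.Linear(1, 1) for name in inputs_of})
        self.order = self._topological_order()

    def _topological_order(self):
        # Depth-first topological sort so each node runs after all its inputs.
        # Assumption: the graph is acyclic (cycles are not detected here).
        order, visited = [], set()

        def visit(name):
            if name in visited:
                return
            visited.add(name)
            for src in self.inputs_of[name]:
                visit(src)
            order.append(name)

        for name in self.inputs_of:
            visit(name)
        return order

    def forward(self, x):
        outputs = {}
        for name in self.order:
            srcs = self.inputs_of[name]
            # Assumption: multiple inputs are summed, so every node keeps input size 1.
            inp = x if not srcs else sum(outputs[s] for s in srcs)
            outputs[name] = self.nodes[name](inp)
        return outputs  # output of every node, keyed by name

    def successors(self, name):
        # Nodes that consume the output of `name`.
        return [dst for dst, srcs in self.inputs_of.items() if name in srcs]


if __name__ == "__main__":
    # The graph from the question: A-B, A-C, B-D, C-D.
    net = SparseGraphNet({"A": [], "B": ["A"], "C": ["A"], "D": ["B", "C"]})
    out = net(torch.randn(4, 1))
    print(out["D"].shape)       # torch.Size([4, 1])
    print(net.inputs_of["D"])   # ['B', 'C']
    print(net.successors("A"))  # ['B', 'C']
```

A connection that skips layers is just another entry in the input list, e.g. `"D": ["A", "C"]`. If you prefer to keep a hand-written forward instead, `torch.fx.symbolic_trace(model)` returns a traced module whose `graph` lists each call together with the values it consumes (`print(traced.graph)`), which is another way to recover who is connected to whom.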