I am trying to implement a graph in PyTorch in which each layer contains a single node and connections between layers may skip several layers. As a simple example, the code below attempts to connect four nodes (A, B, C and D) with the edges A-B, A-C, B-D and C-D. Even if it worked, it is not clear to me how I could access the layers to find out which node is connected to which:
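import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # One single-unit layer per graph node
        self.A = nn.Linear(1, 1)
        self.B = nn.Linear(1, 1)
        self.C = nn.Linear(1, 1)
        self.D = nn.Linear(1, 1)

    def forward(self, x):
        a = self.A(x)
        b = self.B(a)  # edge A-B
        c = self.C(a)  # edge A-C
        # Edges B-D and C-D: D is Linear(1, 1), so its two inputs
        # are summed here (an assumption; concatenation is another option)
        out = self.D(b + c)
        return out


if __name__ == '__main__':
    my_nn = Net()
    print("I'm a net!")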
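
To make the "who is connected to whom" part concrete, here is a minimal sketch of one possible approach: keep the edge list as plain Python data next to the layers, so the wiring can be queried directly. The names `GraphNet`, `edges`, `order` and `parents` are hypothetical, the topological order is assumed to be supplied by the caller, and summing the parent outputs is just one choice among several.

```python
import torch
import torch.nn as nn


class GraphNet(nn.Module):
    """Hypothetical sketch: one Linear node per layer, wired by an explicit edge list."""

    def __init__(self, edges, order):
        super().__init__()
        # Node names in topological order (assumed to be given by the caller)
        self.order = list(order)
        # parents[name] lists the nodes that feed into `name`
        self.parents = {name: [] for name in self.order}
        for src, dst in edges:
            self.parents[dst].append(src)
        # nn.ModuleDict registers each node's parameters with the module
        self.nodes = nn.ModuleDict({name: nn.Linear(1, 1) for name in self.order})

    def forward(self, x):
        outputs = {}
        for name in self.order:
            preds = self.parents[name]
            # Root nodes read the network input; other nodes sum their parents'
            # outputs (summation is an assumption of this sketch)
            inp = x if not preds else sum(outputs[p] for p in preds)
            outputs[name] = self.nodes[name](inp)
        # The last node in the order is treated as the network output
        return outputs[self.order[-1]]


edges = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")]
net = GraphNet(edges, order=["A", "B", "C", "D"])
y = net(torch.randn(5, 1))
print(net.parents["D"])  # nodes feeding into D
```

With this layout, `net.parents[name]` answers the connectivity question for any node, and `net.nodes[name]` gives access to the corresponding layer.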