ValueError when visualizing a network with NetworkX

Hello all!

I was told that the NetworkX package is helpful for creating complex graphs, so I decided to use it to visualize the logic behind forward propagation, with each layer of the neural network represented as a subset of nodes. Unfortunately, I ran into this error:

ValueError: all nodes must have subset_key (default='subset') as data

However, I already pass `subset_key='subset'` to `nx.multipartite_layout` in my code. I've attached the graph-creation part of my code below. Any modifications or suggestions would be greatly appreciated!
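The error message points at a likely cause, assuming the installed NetworkX behaves as documented. In `nx.multipartite_layout`, `subset_key` only *names* the node attribute that the layout reads. It does not create that attribute. Every node therefore needs a value stored under that key, for example via `G.add_node(name, subset=layer_index)` or `nx.set_node_attributes`. Nodes added with a plain `G.add_nodes_from([...])` have no such value. A minimal sketch on a toy three-node graph (the node names are made up for illustration):

```python
import networkx as nx

# Toy graph whose nodes carry no 'subset' attribute yet (names are illustrative)
G = nx.DiGraph()
G.add_edges_from([("a", "b"), ("b", "c")])

raised = False
try:
    nx.multipartite_layout(G, subset_key="subset")
except (ValueError, nx.NetworkXError) as err:
    # Depending on the NetworkX version, this is a ValueError or a NetworkXError
    raised = True
    print(type(err).__name__, err)

# Store each node's layer index under the key that subset_key names
nx.set_node_attributes(G, {"a": 0, "b": 1, "c": 2}, name="subset")
pos = nx.multipartite_layout(G, subset_key="subset")
print({node: xy.round(2).tolist() for node, xy in pos.items()})
```

After the attribute is set, the same call succeeds and places each layer in its own column.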

```python
import matplotlib.pyplot as plt
import networkx as nx

# input_size, hidden_size, weights_hidden and weights_output are defined earlier in the script

# Create a graph object
G = nx.DiGraph()

# Adding nodes & edges
G.add_nodes_from(['Input Layer'] + ['Hidden Layer ' + str(i) for i in range(1, hidden_size+1)] + ['Output Layer'])

for i in range(input_size):
    G.add_edge('Input Layer', 'Hidden Layer 1', weight=weights_hidden[i][0])
for i in range(hidden_size):
    if i < hidden_size-1:
        G.add_edge('Hidden Layer ' + str(i+1), 'Hidden Layer ' + str(i+2), weight=weights_hidden[:, i+1])
    G.add_edge('Hidden Layer ' + str(i+1), 'Output Layer', weight=weights_output[i])

# Set positions for nodes
pos = nx.multipartite_layout(G, subset_key='subset')
nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=500, alpha=0.8)

edge_labels = {(u, v): round(d['weight'], 2) for u, v, d in G.edges(data=True)}
nx.draw_networkx_edges(G, pos, width=2, alpha=0.8, arrows=True)
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

node_labels = {node: node for node in G.nodes}
nx.draw_networkx_labels(G, pos, node_labels, font_size=10, font_weight='bold')

plt.axis('off')

# plotting
plt.title('Feed-forward Neural Network Architecture')
plt.tight_layout()
plt.show()
```
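Here is a sketch of the whole plot with each neuron as a node and each layer as a subset, which matches the description at the top of the post. It assumes a single hidden layer, with `weights_hidden` of shape (input_size, hidden_size) and `weights_output` of shape (hidden_size, output_size). The sizes, the random weights, the node names (`x1`, `h1`, `y1`) and the helper `build_network_graph` are hypothetical stand-ins, not taken from the original code.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Hypothetical sizes and weights, standing in for the values defined elsewhere
input_size, hidden_size, output_size = 3, 4, 2
rng = np.random.default_rng(0)
weights_hidden = rng.normal(size=(input_size, hidden_size))   # assumed shape: inputs x hidden
weights_output = rng.normal(size=(hidden_size, output_size))  # assumed shape: hidden x outputs


def build_network_graph(weights_hidden, weights_output):
    """One node per neuron; the 'subset' attribute holds the layer index."""
    n_in, n_hidden = weights_hidden.shape
    n_out = weights_output.shape[1]
    G = nx.DiGraph()
    layers = [
        [f"x{i + 1}" for i in range(n_in)],      # layer 0: inputs
        [f"h{j + 1}" for j in range(n_hidden)],  # layer 1: hidden
        [f"y{k + 1}" for k in range(n_out)],     # layer 2: outputs
    ]
    # Every node gets its layer index under 'subset', which multipartite_layout reads
    for layer_index, names in enumerate(layers):
        G.add_nodes_from(names, subset=layer_index)
    # One edge per scalar weight, so no edge is overwritten by a later add_edge call
    for i, src in enumerate(layers[0]):
        for j, dst in enumerate(layers[1]):
            G.add_edge(src, dst, weight=float(weights_hidden[i, j]))
    for j, src in enumerate(layers[1]):
        for k, dst in enumerate(layers[2]):
            G.add_edge(src, dst, weight=float(weights_output[j, k]))
    return G


G = build_network_graph(weights_hidden, weights_output)
pos = nx.multipartite_layout(G, subset_key="subset")

nx.draw_networkx_nodes(G, pos, node_color="lightblue", node_size=500, alpha=0.8)
nx.draw_networkx_edges(G, pos, width=2, alpha=0.8, arrows=True)
edge_labels = {(u, v): round(d["weight"], 2) for u, v, d in G.edges(data=True)}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)
nx.draw_networkx_labels(G, pos, font_size=10, font_weight="bold")

plt.axis("off")
plt.title("Feed-forward Neural Network Architecture")
plt.tight_layout()
plt.show()
```

This sketch makes two design choices. First, the `subset` attribute is set when the nodes are added, so `multipartite_layout` finds it on every node. Second, every weight gets its own edge. In the original loop, `G.add_edge('Input Layer', 'Hidden Layer 1', ...)` is called `input_size` times on the same node pair. A `DiGraph` keeps at most one edge per pair, so each call only overwrites the previous weight. Also, if `weights_hidden` is a 2-D NumPy array, `weights_hidden[:, i+1]` stores a whole column rather than a single number as the edge weight.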