I’m curious to know the best method for creating a “branch” from a linear layer.
For example, suppose I have:
import torch
root_node = torch.nn.Linear(1, 2)
The layer takes 1 input feature and produces 2 output features. Can I access each of those outputs independently to build a branch-like structure? Something like:
branch_1 = torch.nn.Linear(1, 2)
branch_2 = torch.nn.Linear(1, 2)
X = root_node(input_data)  # input_data: (batch, 1) -> X: (batch, 2)
# X[:, 0:1] keeps output 0 of every sample as a (batch, 1) column
b1 = branch_1(X[:, 0:1])
b2 = branch_2(X[:, 1:2])
Or something to that effect?
This doesn’t raise an error, but I’m not sure whether I’d be quietly breaking something in the gradient computation.
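A minimal sketch of one way to check this, assuming batch-first input of shape (batch, 1). The class name `BranchNet` and the toy loss `b1.sum() + b2.sum()` are hypothetical, chosen only so the root layer's expected gradient can be derived by hand. Note that with batched input, `X[0]` selects the first *sample*, not the first output feature; column slicing such as `X[:, 0:1]` selects one output for every sample. The sketch wraps the layers in one `nn.Module`, compares autograd's gradient for the root layer against the hand-derived value, and then shows that when only `branch_1` feeds the loss, the root layer's row for output 1 receives zero gradient.

```python
import torch
import torch.nn as nn


class BranchNet(nn.Module):
    """Hypothetical wrapper: one shared root layer feeding two branches."""

    def __init__(self):
        super().__init__()
        self.root_node = nn.Linear(1, 2)
        self.branch_1 = nn.Linear(1, 2)
        self.branch_2 = nn.Linear(1, 2)

    def forward(self, x):
        X = self.root_node(x)  # (batch, 1) -> (batch, 2)
        # 0:1 and 1:2 keep the feature dimension, so each slice is
        # (batch, 1), which is what the branch layers expect.
        b1 = self.branch_1(X[:, 0:1])
        b2 = self.branch_2(X[:, 1:2])
        return b1, b2


torch.manual_seed(0)
model = BranchNet()
x = torch.randn(8, 1)

# 1) Both branches feed the loss.
b1, b2 = model(x)
(b1.sum() + b2.sum()).backward()

# Hand derivation for this loss: d loss / d X[n, k] = sum_j branch_k.weight[j, 0]
# for every sample n, so the root gradients are simple products.
s = torch.stack([
    model.branch_1.weight[:, 0].sum(),
    model.branch_2.weight[:, 0].sum(),
]).detach()                                   # shape (2,)
expected_w_grad = s.unsqueeze(1) * x.sum()    # shape (2, 1)
expected_b_grad = s * x.shape[0]              # shape (2,)

weight_ok = torch.allclose(model.root_node.weight.grad, expected_w_grad, atol=1e-6)
bias_ok = torch.allclose(model.root_node.bias.grad, expected_b_grad, atol=1e-6)

# 2) Only branch_1 feeds the loss: root output 1 never reaches it.
model.zero_grad()
b1, _ = model(x)
b1.sum().backward()
row1_zero = bool((model.root_node.weight.grad[1] == 0).all())

print(weight_ok, bias_ok, row1_zero)  # True True True
```

Slicing and indexing are ordinary differentiable operations in PyTorch, so splitting a layer's output this way routes each branch's gradient back to the matching output of the shared layer, and the contributions from both branches are summed there. `torch.split(X, 1, dim=1)` is an equivalent way to get the same two (batch, 1) columns.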