I want to use PyTorch to build a somewhat unusual network with the following structure:
In this network, the green boxes are the known variables: the input vector x and the m label vectors y^(1), …, y^(m). Each y^(i) is a one-hot vector, either (1,0) or (0,1). The remaining gray boxes are hidden representations (e.g., z^(1,1) and z^(2,m)).
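A minimal sketch of how the m one-hot labels might be held in a single tensor. It assumes (this is not stated in the question) that the labels are stacked into shape `(batch, m, 2)`, with class index 0 meaning (1,0) and index 1 meaning (0,1). The index form is what `nn.CrossEntropyLoss` expects as a target. The batch contents and the value m = 4 are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 3 samples, each with m = 4 binary labels.
# Class index 0 stands for the one-hot vector (1,0), index 1 for (0,1).
m = 4
class_idx = torch.tensor([[0, 1, 1, 0],
                          [1, 1, 0, 0],
                          [0, 0, 0, 1]])          # shape (batch, m), int64

# Stacked one-hot labels y^(1), ..., y^(m): shape (batch, m, 2)
y_onehot = F.one_hot(class_idx, num_classes=2).float()

# Back to class indices, the target format nn.CrossEntropyLoss expects
recovered = y_onehot.argmax(dim=-1)               # shape (batch, m)
```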
My question is: for a given positive integer m > 1, how do I implement such a network as an `nn.Module` subclass? I guess I may need to use `nn.ModuleList()` and the
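Because the figure is not reproduced here, the following is only a sketch under an assumed structure: one independent branch per label, x → z^(1,j) → z^(2,j) → 2-class logits for y^(j), for j = 1, …, m. The class name `MultiLabelBranchNet`, the layer sizes, and the ReLU activations are all hypothetical. The point it illustrates is that building the m branches inside an `nn.ModuleList` registers their parameters with the parent module, so the network can be constructed for any m > 1.

```python
import torch
import torch.nn as nn


class MultiLabelBranchNet(nn.Module):
    """Hypothetical sketch: one branch per label y^(j), j = 1..m.

    Assumed structure (not confirmed by the original figure):
        x -> z^(1,j) -> z^(2,j) -> 2-class logits for y^(j)
    """

    def __init__(self, in_dim: int, hidden_dim: int, m: int):
        super().__init__()
        if m <= 1:
            raise ValueError("m must be an integer > 1")
        # nn.ModuleList registers every branch, so its parameters show up in
        # .parameters(), move with .to(device), and are saved in state_dict().
        self.branches = nn.ModuleList(
            nn.ModuleDict({
                "layer1": nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU()),      # x -> z^(1,j)
                "layer2": nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()),  # z^(1,j) -> z^(2,j)
                "head": nn.Linear(hidden_dim, 2),                                       # z^(2,j) -> logits
            })
            for _ in range(m)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = []
        for branch in self.branches:
            z1 = branch["layer1"](x)
            z2 = branch["layer2"](z1)
            logits.append(branch["head"](z2))
        # Stack per-label logits into shape (batch, m, 2)
        return torch.stack(logits, dim=1)


# Usage with made-up sizes: batch of 5, input dim 10, m = 3 labels
torch.manual_seed(0)
model = MultiLabelBranchNet(in_dim=10, hidden_dim=16, m=3)
x = torch.randn(5, 10)
targets = torch.randint(0, 2, (5, 3))               # class index per label
logits = model(x)
# Flatten (batch, m, 2) -> (batch*m, 2) so one CrossEntropyLoss covers all m labels
loss = nn.CrossEntropyLoss()(logits.reshape(-1, 2), targets.reshape(-1))
loss.backward()
```

If the real figure has z^(2,j) depending on several z^(1,·) (cross-branch connections), the loop in `forward` would first collect all z^(1,j) and then combine them before computing each z^(2,j); the `nn.ModuleList` registration stays the same.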
Could anyone offer further suggestions or comments? Thanks in advance.