I have an architecture composed of two components:
- N encoders
- N decoders
    import torch.nn as nn

    class MODEL(nn.Module):
        def __init__(self, N, **kwargs):
            super(MODEL, self).__init__()
            encoders = [Encoder(**kwargs)] * N
            decoders = [Decoder(**kwargs)] * N

        def forward(self, *input):
            pass
I'm not sure how to properly call each encoder's and decoder's forward method (each one takes input).
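A minimal sketch of one common pattern, assuming each encoder/decoder pair receives the same input independently (if the pairs are meant to be chained, feed each output into the next pair instead). The Encoder and Decoder bodies and the in_dim/hidden_dim arguments are hypothetical placeholders. The key points are storing the submodules on self in an nn.ModuleList and calling each module like a function, enc(x), rather than enc.forward(x), so PyTorch's hooks still run.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    # Hypothetical encoder: a single linear layer followed by ReLU.
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):
    # Hypothetical decoder: maps the hidden code back to the input size.
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.net = nn.Linear(hidden_dim, in_dim)

    def forward(self, h):
        return self.net(h)


class MODEL(nn.Module):
    def __init__(self, N, **kwargs):
        super().__init__()
        # nn.ModuleList registers every submodule, so model.parameters() sees them.
        self.encoders = nn.ModuleList([Encoder(**kwargs) for _ in range(N)])
        self.decoders = nn.ModuleList([Decoder(**kwargs) for _ in range(N)])

    def forward(self, x):
        outputs = []
        # Call each module like a function (enc(x)), not enc.forward(x).
        for enc, dec in zip(self.encoders, self.decoders):
            outputs.append(dec(enc(x)))
        return outputs


model = MODEL(3, in_dim=8, hidden_dim=4)
outs = model(torch.randn(2, 8))
print(len(outs), outs[0].shape)  # 3 torch.Size([2, 8])
```

Returning a list keeps the per-pair outputs separate; torch.stack(outputs) would combine them if they share a shape.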
I also use a custom loss function: will I need to override backward as well?
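As a sketch: if the custom loss is written entirely with differentiable torch operations, autograd derives the backward pass automatically and calling loss.backward() is enough. A hand-written backward (via torch.autograd.Function) is typically only needed when the loss leaves autograd, for example by going through NumPy. The weighted_mse function below is a hypothetical example loss, not anything from the original model.

```python
import torch
import torch.nn as nn


def weighted_mse(pred, target, weight=2.0):
    # Hypothetical custom loss built only from differentiable torch ops,
    # so autograd can compute its gradient without a custom backward.
    return weight * ((pred - target) ** 2).mean()


model = nn.Linear(4, 1)
x, y = torch.randn(5, 4), torch.randn(5, 1)

loss = weighted_mse(model(x), y)
loss.backward()  # gradients are written into each parameter's .grad

print(model.weight.grad.shape)  # torch.Size([1, 4])
```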
Also, am I constructing the encoders properly, or is there a better or more convenient way?
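A small demonstration of two pitfalls in the construction above, using nn.Linear as a stand-in for Encoder: list multiplication repeats one object N times (so all "copies" share weights), and modules kept in a plain Python list (or in a local variable rather than on self) are not registered with the parent module. A list comprehension wrapped in nn.ModuleList avoids both.

```python
import torch.nn as nn

N = 3

# List multiplication repeats the *same* module object N times.
shared = [nn.Linear(4, 4)] * N
print(shared[0] is shared[1])  # True

# A comprehension builds N independent modules; ModuleList registers them.
distinct = nn.ModuleList([nn.Linear(4, 4) for _ in range(N)])
print(distinct[0] is distinct[1])  # False
print(sum(p.numel() for p in distinct.parameters()))  # 60


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list attribute: these layers are NOT registered as submodules.
        self.layers = [nn.Linear(4, 4) for _ in range(N)]


print(len(list(Holder().parameters())))  # 0
```

If every copy should start from identical weights, copy.deepcopy of a single instance inside the comprehension is another option.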