Model initialization problem

Hello everyone,

I have a question about a PyTorch implementation.

ref

According to the figure above, the inputs to model A (a U-Net) are Ek (of shape BxHxW) and the previous three reconstruction images (each of shape 1xHxW). So at time k, the model's input has shape (B+3)xHxW.
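To make the shapes concrete, here is a minimal sketch of how I understand the input being assembled. The function name `build_input` and the sizes are placeholders of mine, not taken from the reference, and I am assuming B is the channel count of Ek with no separate batch dimension.

```python
import torch

# Hypothetical sizes for illustration only; B is the channel count of Ek.
B, H, W = 4, 32, 32


def build_input(E_k, prev_recons):
    """Stack Ek (B x H x W) with three previous reconstructions (each 1 x H x W)."""
    if len(prev_recons) != 3:
        raise ValueError("need exactly three previous reconstructions")
    # Concatenate along the channel dimension: B + 1 + 1 + 1 = B + 3 channels.
    return torch.cat([E_k, *prev_recons], dim=0)


E_k = torch.randn(B, H, W)
prev_recons = [torch.randn(1, H, W) for _ in range(3)]

x = build_input(E_k, prev_recons)
print(x.shape)  # torch.Size([7, 32, 32])
```

With fewer than three entries in `prev_recons`, this sketch raises an error rather than guessing how to fill the missing channels.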

My question is: how should the initialization be handled at times k = 0, 1, 2, when fewer than three previous reconstruction images exist?