How to segment a single-channel future frame from multiple past frames of a multi-channel satellite image?

Hi all!

I’m relatively new to PyTorch. I want to segment the next frame from multiple past frames of a multi-channel satellite image. That is, the input shape is (batch, time_frames, satellite_channels, height, width), e.g. with time_frames = 5 and satellite_channels = 7, and the output shape is (batch, 1, 1, height, width).
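For concreteness, here are dummy tensors with the shapes I mean. The batch size of 2, the 64×64 tile size and the binary (0/1) target are just assumptions for illustration:

```python
import torch

# Illustrative sizes (assumed): batch of 2, 64x64 tiles
batch, time_frames, satellite_channels, height, width = 2, 5, 7, 64, 64

# Input: 5 past frames, each with all 7 satellite channels
x = torch.randn(batch, time_frames, satellite_channels, height, width)

# Target: one future frame with a single channel (assumed binary mask here)
y = torch.randint(0, 2, (batch, 1, 1, height, width)).float()

print(x.shape)  # torch.Size([2, 5, 7, 64, 64])
print(y.shape)  # torch.Size([2, 1, 1, 64, 64])
```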

I want to do this with a U-Net model, such as a 3D U-Net, but I’m not sure how to represent the data in a way that makes sense in PyTorch. Is there a known way to do this?

I’m not sure whether to use the satellite channel dimension or the time dimension as the depth in Conv3d. If the time frames are used as depth, how do I make sure the output depth is 1? Is that done with the kernel_size parameter?
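This is the kind of thing I have in mind, as an untested sketch assuming time is used as depth. The single Conv3d standing in for the U-Net body and the 16 feature channels are placeholders I made up. Since nn.Conv3d expects (N, C_in, D, H, W), I permute the satellite channels to dim 1 and time to dim 2, then use a kernel that spans all 5 time steps with no depth padding, so the depth should shrink from 5 to 5 - 5 + 1 = 1:

```python
import torch
import torch.nn as nn

batch, time_frames, satellite_channels, height, width = 2, 5, 7, 64, 64
x = torch.randn(batch, time_frames, satellite_channels, height, width)

# Time as depth: (B, T, C, H, W) -> (B, C, T, H, W) = (N, C_in, D, H, W).
# (Alternative: the original layout already reads as (N, C_in=T, D=C, H, W),
# i.e. satellite channels as depth, with no permute needed.)
x3d = x.permute(0, 2, 1, 3, 4)  # (2, 7, 5, 64, 64)

# Placeholder for a 3D U-Net body: kernel 3 with padding 1 keeps D, H, W
body = nn.Conv3d(satellite_channels, 16, kernel_size=3, padding=1)
features = body(x3d)  # (2, 16, 5, 64, 64)

# Collapse depth: kernel depth = time_frames, no depth padding -> D_out = 1;
# a 1x1 spatial kernel leaves H and W unchanged
head = nn.Conv3d(16, 1, kernel_size=(time_frames, 1, 1))
out = head(features)  # (2, 1, 1, 64, 64) as (N, C_out, D, H, W)

# Back to the (batch, time, channel, H, W) order of the target
out = out.permute(0, 2, 1, 3, 4)  # (2, 1, 1, 64, 64)
print(out.shape)
```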

Thank you in advance for any help!