3D CNN pose estimation

Hi,
I am trying to implement an architecture from a paper and ran into the problem below.

note:

  • I am using a 3D CNN architecture for pose estimation on image data (videos split into frames)
  • I am using 17 keypoints, so my target heatmap has 17 channels
  • Loss function: MSELoss()
  • Output layer: softmax
  • Heatmap dims: (1, 17, 200, 200)
  • Input dims: (1, 3, 4, 200, 200)
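
To make the shapes above concrete, here is a minimal PyTorch sketch. It assumes the input layout is (batch, channels, frames, height, width), as nn.Conv3d expects. TinyTemporalNet is a hypothetical stand-in and not the architecture from the paper; the softmax output layer is left out for brevity.

```python
import torch
import torch.nn as nn


class TinyTemporalNet(nn.Module):
    """Hypothetical stand-in network, only meant to reproduce the shapes."""

    def __init__(self, num_keypoints=17, num_frames=4):
        super().__init__()
        # 3D convolution over (frames, height, width); padding keeps all sizes.
        self.conv3d = nn.Conv3d(3, 16, kernel_size=3, padding=1)
        # A kernel spanning all frames collapses the temporal axis to length 1.
        self.collapse = nn.Conv3d(16, num_keypoints, kernel_size=(num_frames, 1, 1))

    def forward(self, x):
        x = torch.relu(self.conv3d(x))  # (N, 16, T, H, W)
        x = self.collapse(x)            # (N, 17, 1, H, W)
        return x.squeeze(2)             # (N, 17, H, W)


model = TinyTemporalNet()
frames = torch.randn(1, 3, 4, 200, 200)  # one clip of 4 RGB frames
target = torch.rand(1, 17, 200, 200)     # placeholder target heatmaps
pred = model(frames)
loss = nn.MSELoss()(pred, target)
print(pred.shape, target.shape)
```

The point of the sketch is that the temporal axis has to be collapsed somewhere inside the network, so the prediction and the target are both single-frame heatmaps.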

Question:
What should my target heatmap contain, given that the network takes an input of shape (1, 3, 4, 200, 200) and outputs a heatmap of shape (1, 17, 200, 200)?
The target heatmap can only describe a single image, but my input is a sequence of 4 frames from a video.
How can I resolve this mismatch so that I can train the network?
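
One option I am considering, as an assumption rather than something confirmed, is to build the target from the keypoint annotations of a single frame of the clip (for example the last one), rendering one Gaussian per keypoint. The sketch below uses NumPy; keypoints_to_heatmaps, the sigma value, and the random annotations are all hypothetical and only illustrate the idea.

```python
import numpy as np


def keypoints_to_heatmaps(keypoints_xy, height=200, width=200, sigma=3.0):
    """Render one 2D Gaussian per (x, y) keypoint; returns an array (K, H, W)."""
    ys = np.arange(height, dtype=np.float32)[:, None]  # column of row indices
    xs = np.arange(width, dtype=np.float32)[None, :]   # row of column indices
    heatmaps = np.zeros((len(keypoints_xy), height, width), dtype=np.float32)
    for k, (x, y) in enumerate(keypoints_xy):
        # Broadcasting (H, 1) against (1, W) gives squared distances of shape (H, W).
        heatmaps[k] = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))
    return heatmaps


# Hypothetical annotations: 4 frames, 17 keypoints, (x, y) pixel coordinates.
rng = np.random.default_rng(0)
clip_keypoints = rng.uniform(0, 199, size=(4, 17, 2))

# Assumption: supervise only the last frame of the clip.
target = keypoints_to_heatmaps(clip_keypoints[-1])[None]  # adds the batch axis
print(target.shape)
```

With this choice the 4 input frames act as temporal context, and the (1, 17, 200, 200) target describes the pose in one frame only.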

reference:
[2104.08029] T-LEAP: occlusion-robust pose estimation of walking cows using temporal information.