torch.flatten() causes zero grad

Thank you so much!
By using nn.BCEWithLogitsLoss, I got the results below:

tensor(0.0332)
tensor(0.0663)
----------------
tensor(-0.2500)
tensor(-0.5000)
----------------
tensor(-0.2265)
tensor(-0.4531)
----------------
tensor(0.2500)
tensor(0.5000)
----------------
tensor(0.2344)
tensor(0.4688)
----------------
tensor(0.2500)
tensor(0.5000)
----------------
tensor(0.2579)
tensor(0.5156)
----------------
tensor(-0.1406)
tensor(-0.2812)
----------------
tensor(-0.2500)
tensor(-0.5000)
----------------
tensor(-0.2421)
tensor(-0.4844)
----------------
tensor(-0.2657)
tensor(-0.5312)
----------------
tensor(-0.2578)
tensor(-0.5156)
----------------
tensor(0.2578)
tensor(0.5156)
----------------
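To sanity-check printed gradients like these, it may help to know what nn.BCEWithLogitsLoss should produce. With the default reduction='mean' and no weight or pos_weight, the loss per element is max(x, 0) - x*y + log(1 + exp(-|x|)), and its gradient with respect to each logit x_i is (sigmoid(x_i) - y_i) / N. That gradient is nonzero unless sigmoid(x_i) equals y_i. The following is a minimal plain-Python sketch of that formula, with a finite-difference check. It assumes mean reduction over a flat list of logits, and the helper names (bce_with_logits, bce_with_logits_grad, numeric_grad) are hypothetical, not part of PyTorch.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits, targets):
    # Mean-reduced binary cross-entropy on raw logits (assumed: no weights),
    # using the stable form max(x, 0) - x*y + log(1 + exp(-|x|))
    n = len(logits)
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / n


def bce_with_logits_grad(logits, targets):
    # Analytic gradient of the mean-reduced loss w.r.t. each logit:
    # d loss / d x_i = (sigmoid(x_i) - y_i) / N
    n = len(logits)
    return [(sigmoid(x) - y) / n for x, y in zip(logits, targets)]


def numeric_grad(logits, targets, eps=1e-6):
    # Central finite differences, used only to check the analytic formula
    grads = []
    for i in range(len(logits)):
        plus = list(logits)
        plus[i] += eps
        minus = list(logits)
        minus[i] -= eps
        diff = bce_with_logits(plus, targets) - bce_with_logits(minus, targets)
        grads.append(diff / (2 * eps))
    return grads


if __name__ == "__main__":
    logits = [1.5, -2.0, 0.3]
    targets = [1.0, 0.0, 1.0]
    print(bce_with_logits_grad(logits, targets))
    print(numeric_grad(logits, targets))
```

For example, two logits of 0.0 with targets 0.0 and 1.0 give gradients of 0.25 and -0.25, since sigmoid(0) = 0.5 and N = 2.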