I had a similar case. I used:
y = torch.zeros([batch_size, c, h, w], requires_grad=False)
Then I updated the values of y based on the network output and applied a loss function to y. That worked for me.
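Here is a minimal sketch of that approach. It makes a few assumptions. The update copies selected entries of the output into y; here those are the positive activations, which is a hypothetical rule. A single `nn.Conv2d` stands in for the real network, and the all-ones target is also made up for illustration. The sketch shows that y does not require grad when it is created. However, autograd records the in-place assignment from the output, so a loss computed on y still backpropagates into the network's weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, c, h, w = 2, 3, 4, 4

# Hypothetical network: a single conv layer standing in for the real model.
net = nn.Conv2d(c, c, kernel_size=3, padding=1)
x = torch.randn(batch_size, c, h, w)
out = net(x)  # network output, tracked by autograd

# Buffer with the same shape as the output; it does not require grad itself.
y = torch.zeros([batch_size, c, h, w], requires_grad=False)

# Assumed update rule (hypothetical): copy only the positive activations into y.
mask = out > 0
y[mask] = out[mask]

# The in-place copy is recorded by autograd, so y is now part of the graph.
print(y.requires_grad)  # True

# Hypothetical target: an all-ones tensor of the same shape.
target = torch.ones_like(y)
loss = F.mse_loss(y, target)
loss.backward()

# Gradients flowed from the loss on y back into the conv weights.
print(net.weight.grad is not None)  # True
```

If y is instead meant to be a fixed target, fill it from `out.detach()` or inside `torch.no_grad()`. That way no gradient flows through it.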