Hi there,
Please wrap your code in preformatted text so it's easier to read.
I think your problem is the shape of your input, i.e. [3, 256, 256, 1].
To clarify, PyTorch expects [batch, channel, height, width], so you need to reorder the dimensions of your input accordingly (here presumably to [3, 1, 256, 256]).
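Here is a minimal sketch of that reordering. It assumes your tensor is channels-last, i.e. [batch, height, width, channel], with the trailing 1 being a single channel. `x` is a hypothetical stand-in for your data:

```python
import torch

# Hypothetical input with the same shape as yours: [3, 256, 256, 1],
# assumed to be channels-last: [batch, height, width, channel].
x = torch.rand(3, 256, 256, 1)

# Move the channel dimension (index 3) to position 1 to get
# PyTorch's expected [batch, channel, height, width] layout.
x_nchw = x.permute(0, 3, 1, 2).contiguous()

print(x_nchw.shape)  # torch.Size([3, 1, 256, 256])
```

Use `permute` rather than `view`/`reshape` here. A reshape would only reinterpret the memory and scramble the pixels, whereas `permute` actually moves the axes.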
I found more detail about this here: https://discuss.pytorch.org/t/dimensions-of-an-input-image/19439/2?u=ptrblk
Hope this helps.