Your explanation is right in general.
Just some minor issues:
In PyTorch, images are represented as `[channels, height, width]`, so a color image would be `[3, 256, 256]`.
During training you will get batches of images, so the input shape in the `forward` method will get an additional batch dimension at dim0: `[batch_size, channels, height, width]`.
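A quick sketch of these shapes (the sizes are just example values):

```python
import torch

# A single color image: [channels, height, width]
img = torch.randn(3, 256, 256)

# A batch of 8 such images, as seen inside forward():
# [batch_size, channels, height, width]
batch = torch.randn(8, 3, 256, 256)
print(batch.shape)  # torch.Size([8, 3, 256, 256])
```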
The `.view()` operation gives you a new view on the tensor without copying any data.
This means `view` is cheap to call and you don't have to think too much about potential performance issues.
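You can verify that no copy happens: the view shares the original tensor's storage, so modifying one is visible through the other.

```python
import torch

x = torch.arange(6)   # tensor([0, 1, 2, 3, 4, 5])
y = x.view(2, 3)      # new view on the same underlying storage

# Both tensors point at the same memory, so no data was copied.
print(x.data_ptr() == y.data_ptr())  # True

# Writing through the view changes the original tensor as well.
y[0, 0] = 100
print(x[0])  # tensor(100)
```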
In your example, `x` is reshaped into a tensor with `16*5*5` values in dim1 and all remaining values in dim0; that's what the `-1` stands for.
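As a minimal sketch (assuming the classic tutorial activation of shape `[batch_size, 16, 5, 5]` coming out of `conv2`):

```python
import torch

# Hypothetical conv2 output: [batch_size, 16, 5, 5]
x = torch.randn(4, 16, 5, 5)

# Flatten everything except the batch dimension:
# dim1 becomes 16*5*5 = 400, and -1 tells PyTorch to
# infer dim0 from the remaining number of elements.
x = x.view(-1, 16 * 5 * 5)
print(x.shape)  # torch.Size([4, 400])
```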
In the case of a 256-channel input image, the `view` operation would stay the same, since the spatial dimensions of the activation are unchanged and its channels are defined by the number of kernels of `conv2`.
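To illustrate, here is a hedged sketch of a LeNet-style conv stack (layer sizes assumed from the standard tutorial, with a 32x32 input): only `conv1`'s `in_channels` changes for a 256-channel image, while `conv2` still produces 16 channels, so the same `view` call works.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Only in_channels of the first conv changes for a 256-channel input.
conv1 = nn.Conv2d(256, 6, kernel_size=5)
conv2 = nn.Conv2d(6, 16, kernel_size=5)
pool = nn.MaxPool2d(2)

x = torch.randn(1, 256, 32, 32)          # 256-channel "image"
x = pool(F.relu(conv1(x)))               # [1, 6, 14, 14]
x = pool(F.relu(conv2(x)))               # [1, 16, 5, 5]
x = x.view(-1, 16 * 5 * 5)               # same flatten as before
print(x.shape)  # torch.Size([1, 400])
```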