RuntimeError: Given groups=1, weight of size 64 1 3 3, expected input[1, 4, 224, 224] to have 1 channels, but got 4 channels instead

That issue is solved now, you were right!
But I do not understand why it prints nan at iteration