ValueError: Expected input batch_size (324) to match target batch_size (4)

Yes, the output of print(x.shape) is as follows:
torch.Size([4, 64, 9, 9])
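
For what it's worth, 324 = 4 * 9 * 9, so one plausible cause is that the tensor is flattened with a hard-coded feature size, something like x.view(-1, 64), before the linear layer. That is only a guess, since the forward code is not shown here. Below is a minimal sketch of the mismatch and a possible fix. The random tensor, the names wrong, flat and fc, and the 10 output classes are all assumptions for illustration, not the original model.

```python
import torch

# Stand-in for the activation printed above: (batch, channels, height, width).
x = torch.randn(4, 64, 9, 9)

# Assumed cause: a hard-coded feature size of 64 folds the 9 * 9
# spatial positions into the batch dimension.
wrong = x.view(-1, 64)
print(wrong.shape)  # torch.Size([324, 64])

# Keeping the batch dimension fixed gives one row per sample instead.
flat = x.view(x.size(0), -1)
print(flat.shape)  # torch.Size([4, 5184])

# The next linear layer then needs in_features = 64 * 9 * 9 = 5184.
# The 10 output classes are a hypothetical choice.
fc = torch.nn.Linear(64 * 9 * 9, 10)
logits = fc(flat)

# With 4 logit rows and 4 targets, the loss no longer raises the ValueError.
target = torch.randint(0, 10, (4,))
loss = torch.nn.functional.cross_entropy(logits, target)
```

If that matches the model, changing the flatten to x.view(x.size(0), -1) and setting the first linear layer's in_features to 5184 should make the batch sizes agree.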