I'm training a convolutional VAE on 256x256 grayscale retinal images and hit the error below inside the loss function:

ValueError: Using a target size (torch.Size([1, 1, 256, 256])) that is different to the input size (torch.Size([4, 1, 148, 148])) is deprecated. Please ensure they have the same size

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nc = 1  # number of image channels (grayscale)

def prior(K, alpha):
    # Laplace approximation of a symmetric Dirichlet(alpha) prior in the
    # softmax basis (as in Srivastava & Sutton, "Autoencoding Variational
    # Inference for Topic Models"); returns the mean and variance of the
    # approximating logistic normal
    a = torch.full((1, K), alpha)
    mean = a.log().t() - a.log().mean(1)
    var = ((1 - 2.0 / K) * a.reciprocal()).t() + (1.0 / K ** 2) * a.reciprocal().sum(1)
    return mean.t(), var.t()
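For reference, the moments this function returns are (in my notation, matching the code above)

$$\mu_k = \log\alpha_k - \frac{1}{K}\sum_{i=1}^{K}\log\alpha_i, \qquad
\sigma_k^2 = \frac{1}{\alpha_k}\Bigl(1 - \frac{2}{K}\Bigr) + \frac{1}{K^2}\sum_{i=1}^{K}\frac{1}{\alpha_i},$$

so with a symmetric alpha the prior mean is zero in every dimension and the variance is the same constant for all K components.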

class RetinalVAE(nn.Module):
    def __init__(self):
        super(RetinalVAE, self).__init__()
        self.encoder = nn.Sequential(
            # 256 -> 128
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 -> 64
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 -> 32
            nn.Conv2d(512, 1024, 3, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # 32 -> 29, so the output is (B, 1024, 29, 29), not (B, 1024, 1, 1)
            nn.Conv2d(1024, 1024, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.decoder = nn.Sequential(
            # input is the hard-coded (4, 841, 16, 16) from decode(): 16 -> 19
            nn.ConvTranspose2d(841, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # 19 -> 37
            nn.ConvTranspose2d(1024, 512, 3, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 37 -> 74
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 74 -> 148, so reconstructions are 148x148 rather than 256x256
            nn.ConvTranspose2d(256, nc, 4, 2, 1, bias=False),
            nn.Sigmoid()
        )
        self.fc1 = nn.Linear(1024, 1024)
        self.fc21 = nn.Linear(1024, 10)
        self.fc22 = nn.Linear(1024, 10)

        self.fc3 = nn.Linear(10, 1024)
        self.fc4 = nn.Linear(1024, 1024)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

        # Laplace approximation of a symmetric Dirichlet(0.3) prior over the
        # 10 latent dimensions; stored as frozen (non-trainable) parameters
        self.prior_mean, self.prior_var = map(nn.Parameter, prior(10, 0.3))
        self.prior_logvar = nn.Parameter(self.prior_var.log())
        self.prior_mean.requires_grad = False
        self.prior_var.requires_grad = False
        self.prior_logvar.requires_grad = False


    def encode(self, x):
        conv = self.encoder(x)  # (B, 1024, 29, 29) for a 256x256 input
        # view(-1, 1024) folds the 29*29 = 841 spatial positions into the
        # batch dimension, giving (B*841, 1024)
        h1 = self.fc1(conv.view(-1, 1024))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, gauss_z):
        dir_z = F.softmax(gauss_z, dim=1)
        h3 = self.relu(self.fc3(dir_z))
        deconv_input = self.fc4(h3)  # (B*841, 1024)
        # Hard-coded reshape: only valid when B*841*1024 == 4*841*16*16,
        # i.e. for batch size 1, and it scrambles batch and spatial axes
        deconv_input = deconv_input.view(4, 841, 16, 16)
        return self.decoder(deconv_input)

    def sampling(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
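    # For reference, the reparameterization trick above draws
    #   z = mu + sigma * eps,  sigma = exp(0.5 * logvar),  eps ~ N(0, I),
    # which keeps the sample differentiable with respect to mu and logvar.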


    def forward(self, x):
        mu, logvar = self.encode(x)
        gauss_z = self.sampling(mu, logvar)
        # softmax of the logistic-normal sample; approximately Dirichlet-distributed
        dir_z = F.softmax(gauss_z, dim=1)
        return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

    # Reconstruction + KL divergence loss
    def loss_function(self, recon_x, x, mu, logvar, K):
        beta = 1.0  # KL weight (beta-VAE style)
        # This is the line that raises the ValueError: recon_x is
        # (4, 1, 148, 148) while x is (1, 1, 256, 256), and
        # binary_cross_entropy requires input and target of the same size
        BCE = F.binary_cross_entropy(recon_x.view(4, 1, 148, 148), x.view(1, 1, 256, 256), reduction='sum')
        prior_mean = self.prior_mean.expand_as(mu)
        prior_var = self.prior_var.expand_as(logvar)
        prior_logvar = self.prior_logvar.expand_as(logvar)
        var_division = logvar.exp() / prior_var
        diff = mu - prior_mean
        diff_term = diff * diff / prior_var
        logvar_division = prior_logvar - logvar
        # KL divergence between the approximate posterior and the prior
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - K)
        return BCE + beta * KLD
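The KLD term matches the usual closed form for the KL divergence between two diagonal Gaussians (my reading of var_division, diff_term, and logvar_division above):

$$D_{\mathrm{KL}}\bigl(q \,\|\, p\bigr)
= \frac{1}{2}\sum_{k=1}^{K}\left(
\frac{\sigma_{q,k}^2}{\sigma_{p,k}^2}
+ \frac{(\mu_{q,k}-\mu_{p,k})^2}{\sigma_{p,k}^2}
+ \log\sigma_{p,k}^2 - \log\sigma_{q,k}^2
- 1\right),$$

where summing the trailing -1 over the K dimensions is exactly the "- K" in the code.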


model = RetinalVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
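Probing the model with a dummy batch (a diagnostic snippet of mine, not part of the training script) reproduces the two sizes from the error before any data is involved:

with torch.no_grad():
    x = torch.randn(1, 1, 256, 256, device=device)
    print(model.encoder(x).shape)  # torch.Size([1, 1024, 29, 29])
    recon, *_ = model(x)
    print(recon.shape)             # torch.Size([4, 1, 148, 148])

So the encoder never reduces the 256x256 input to a single 1024-dimensional vector (it stops at 29x29), the view calls fold those 841 positions into the batch axis, and the decoder's upsampling path (16 -> 19 -> 37 -> 74 -> 148) cannot reach 256.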

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        # add a batch dimension; judging by the sizes in the traceback this
        # produces a (1, 1, 256, 256) batch
        data = data.unsqueeze(0)
        optimizer.zero_grad()
        recon_batch, mu, logvar, gauss_z, dir_z = model(data)
        
        loss = model.loss_function(recon_batch, data, mu, logvar, 10)
        loss = loss.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data = data.unsqueeze(0)
            recon_batch, mu, logvar, gauss_z, dir_z = model(data)
            # K must match the 10-dimensional latent, as in train()
            loss = model.loss_function(recon_batch, data, mu, logvar, 10)
            test_loss += loss.mean().item()
    # average over the whole test set, outside the loop
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    # Train
    for epoch in enumerate(dataset):
        train(epoch)
        test(epoch)
ValueError                                Traceback (most recent call last)
<ipython-input-10-f3fe9208c6cf> in <module>
    208     # Train
    209     for epoch in enumerate(dataset):
--> 210         train(epoch)
    211         test(epoch)
    212 

<ipython-input-10-f3fe9208c6cf> in train(epoch)
    174         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
    175 
--> 176         loss = model.loss_function(recon_batch, data, mu, logvar, 10)
    177         loss = loss.mean()
    178         loss.backward()

<ipython-input-10-f3fe9208c6cf> in loss_function(self, recon_x, x, mu, logvar, K)
    149         beta = 1.0
    150         #BCE = F.binary_cross_entropy(recon_x.view(-1, 65536), x.view(-1, 65536), reduction='sum')
--> 151         BCE = F.binary_cross_entropy(recon_x.view(4,1,148,148), x.view(1,1,256,256), reduction='sum')
    152         prior_mean = self.prior_mean.expand_as(mu)
    153         prior_var = self.prior_var.expand_as(logvar)

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2904         reduction_enum = _Reduction.get_enum(reduction)
   2905     if target.size() != input.size():
-> 2906         raise ValueError(
   2907             "Using a target size ({}) that is different to the input size ({}) is deprecated. "
   2908             "Please ensure they have the same size.".format(target.size(), input.size())

ValueError: Using a target size (torch.Size([1, 1, 256, 256])) that is different to the input size (torch.Size([4, 1, 148, 148])) is deprecated. Please ensure they have the same size.
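For what it's worth, the root cause is that the downsampling/upsampling strides and the hard-coded views don't invert each other for 256x256 inputs. Below is a minimal sketch of one shape-consistent alternative, assuming 256x256 grayscale images and keeping the 10-dimensional latent; the channel widths are my own choices, not taken from the model above:

# Sketch only: the encoder ends at (B, 1024, 1, 1) and the decoder starts
# there, so no spatial positions leak into the batch dimension.
encoder = nn.Sequential(
    nn.Conv2d(1, 32, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),      # 256 -> 128
    nn.Conv2d(32, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),     # 128 -> 64
    nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),    # 64 -> 32
    nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),   # 32 -> 16
    nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),   # 16 -> 8
    nn.Conv2d(512, 1024, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),  # 8 -> 4
    nn.Conv2d(1024, 1024, 4, 1, 0, bias=False), nn.LeakyReLU(0.2, inplace=True), # 4 -> 1
)
decoder = nn.Sequential(
    nn.ConvTranspose2d(1024, 1024, 4, 1, 0, bias=False), nn.ReLU(True),  # 1 -> 4
    nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), nn.ReLU(True),   # 4 -> 8
    nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.ReLU(True),    # 8 -> 16
    nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.ReLU(True),    # 16 -> 32
    nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.ReLU(True),     # 32 -> 64
    nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), nn.ReLU(True),      # 64 -> 128
    nn.ConvTranspose2d(32, nc, 4, 2, 1, bias=False), nn.Sigmoid(),       # 128 -> 256
)

With this pair, encode() can keep conv.view(-1, 1024) because the feature map really is (B, 1024, 1, 1), decode() can use deconv_input.view(-1, 1024, 1, 1) instead of the hard-coded (4, 841, 16, 16), and recon_x comes out as (B, 1, 256, 256), so the reconstruction term reduces to

    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

with no view gymnastics and no size mismatch.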