RuntimeError: Given transposed=1, weight of size [841, 1024, 4, 4], expected input[841, 1024, 1, 1] to have 841 channels, but got 1024 channels instead

class RetinalVAE(nn.Module):
    def __init__(self):
        super(RetinalVAE, self).__init__()
        self.encoder = nn.Sequential(
            
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1024, 3, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(1024, 1024, 4, 1, 0, bias=False),
            
            nn.LeakyReLU(0.2, inplace=True),
            
        )

        self.decoder = nn.Sequential(
            
            nn.ConvTranspose2d(841, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(1024, 512, 3, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256,nc, 4, 2, 1, bias=False),
            
            nn.Sigmoid()
            
        )
        self.fc1 = nn.Linear(1024, 1024)
        self.fc21 = nn.Linear(1024, 10)
        self.fc22 = nn.Linear(1024, 10)

        self.fc3 = nn.Linear(10, 1024)
        self.fc4 = nn.Linear(1024, 1024)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

        # Dir prior
        self.prior_mean, self.prior_var = map(nn.Parameter, prior(10, 0.3))
        self.prior_logvar = nn.Parameter(self.prior_var.log())
        self.prior_mean.requires_grad = False
        self.prior_var.requires_grad = False
        self.prior_logvar.requires_grad = False


    def encode(self, x):
        conv = self.encoder(x);
        h1 = self.fc1(conv.view(-1, 1024))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, gauss_z):
        dir_z = F.softmax(gauss_z,dim=1) 
        h3 = self.relu(self.fc3(dir_z))
        deconv_input = self.fc4(h3)
        deconv_input = deconv_input.view(-1,1024,1,1)
        #deconv_input = deconv_input.view(x.size(0), -1)
        return self.decoder(deconv_input)

    def sampling(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std


    def forward(self, x):
        mu, logvar = self.encode(x)
        gauss_z = self.sampling(mu, logvar)        
        dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
        return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

    # Reconstruction + KL divergence losses s
    def loss_function(self, recon_x, x, mu, logvar, K):
        beta = 1.0
        BCE = F.binary_cross_entropy(recon_x.view(-1, 2048), x.view(-1, 2048), reduction='sum')        
        prior_mean = self.prior_mean.expand_as(mu)
        prior_var = self.prior_var.expand_as(logvar)
        prior_logvar = self.prior_logvar.expand_as(logvar)
        var_division = logvar.exp() / prior_var 
        diff = mu - prior_mean 
        diff_term = diff *diff / prior_var 
        logvar_division = prior_logvar - logvar 
        # KL
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - K)        
        return BCE + KLD


model = RetinalVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        data = data.unsqueeze(0)
        optimizer.zero_grad()
        recon_batch, mu, logvar, gauss_z, dir_z = model(data)
        
        loss = model.loss_function(recon_batch, data, mu, logvar, 10)
        loss = loss.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 10 == 0:
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data = data.unsqueeze(0)
            recon_batch, mu, logvar, gauss_z, dir_z = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar, 1)
            test_loss += loss.mean()
            test_loss.item()
            if i == 0:
                n = min(data.size(0), 18)
                test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    # Train
    for epoch in enumerate(dataset):
        train(epoch)
        test(epoch)
        ```
```---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-643ea9d0266e> in <module>
    206     # Train
    207     for epoch in enumerate(dataset):
--> 208         train(epoch)
    209         test(epoch)
    210 

<ipython-input-2-643ea9d0266e> in train(epoch)
    170         data = data.unsqueeze(0)
    171         optimizer.zero_grad()
--> 172         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
    173 
    174         loss = model.loss_function(recon_batch, data, mu, logvar, 10)

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-2-643ea9d0266e> in forward(self, x)
    142         gauss_z = self.sampling(mu, logvar)
    143         dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
--> 144         return self.decode(gauss_z), mu, logvar, gauss_z, dir_z
    145 
    146     # Reconstruction + KL divergence losses s

<ipython-input-2-643ea9d0266e> in decode(self, gauss_z)
    130         deconv_input = deconv_input.view(-1,1024,1,1)
    131         #deconv_input = deconv_input.view(x.size(0), -1)
--> 132         return self.decoder(deconv_input)
    133 
    134     def sampling(self, mu, logvar):

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input, output_size)
    921             input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
    922 
--> 923         return F.conv_transpose2d(
    924             input, self.weight, self.bias, self.stride, self.padding,
    925             output_padding, self.groups, self.dilation)

RuntimeError: Given transposed=1, weight of size [841, 1024, 4, 4], expected input[841, 1024, 1, 1] to have 841 channels, but got 1024 channels instead

Your self.decoder expects 841 input channels as it’s first layer is defined as:

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(841, 1024, 4, 1, 0, bias=False),

while you are explicitly reshaping the input to self.decoder to have 1024 channels:

        deconv_input = deconv_input.view(-1,1024,1,1)
        return self.decoder(deconv_input)

which is then failing with the shape mismatch.

Besides that, I would not recommend to use -1 in the batch dimension as it usually yields to shape mismatch errors later in the training if you are not careful in calculating the feature dimensions.
Use x = x.view(x.size(0), ...) instead.
Also note that you are applying F.softmax multiple times on the same tensor.

Hello After making the corrections as you recommended this what I got

RuntimeError                              Traceback (most recent call last)
<ipython-input-1-e3146389b929> in <module>
    206     # Train
    207     for epoch in enumerate(dataset):
--> 208         train(epoch)
    209         test(epoch)
    210 

<ipython-input-1-e3146389b929> in train(epoch)
    170         data = data.unsqueeze(0)
    171         optimizer.zero_grad()
--> 172         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
    173 
    174         loss = model.loss_function(recon_batch, data, mu, logvar, 10)

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-1-e3146389b929> in forward(self, x)
    142         gauss_z = self.sampling(mu, logvar)
    143         dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
--> 144         return self.decode(gauss_z), mu, logvar, gauss_z, dir_z
    145 
    146     # Reconstruction + KL divergence losses s

<ipython-input-1-e3146389b929> in decode(self, gauss_z)
    130         #deconv_input = deconv_input.view(-1,1024,1,1)
    131         deconv_input = deconv_input.view(deconv_input.size(0), -1)
--> 132         return self.decoder(deconv_input)
    133 
    134     def sampling(self, mu, logvar):

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input, output_size)
    921             input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
    922 
--> 923         return F.conv_transpose2d(
    924             input, self.weight, self.bias, self.stride, self.padding,
    925             output_padding, self.groups, self.dilation)

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [841, 1024, 4, 4], but got 2-dimensional input of size [841, 1024] instead

You are now flattening the activation to a 2D tensor, which is also wrong as 4 dimensions are expected:

deconv_input = deconv_input.view(deconv_input.size(0), -1)
return self.decoder(deconv_input)

Add the spatial sizes e.g. as: view(x.size(0), 841, 1, 1) and it would work assuming the number of elements is correct for the desired shape.

After still using x = x.view(x.size(0),....) still got the same results

class RetinalVAE(nn.Module):
    def __init__(self):
        super(RetinalVAE, self).__init__()
        self.encoder = nn.Sequential(
            
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1024, 3, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(1024, 1024, 4, 1, 0, bias=False),
            
            nn.LeakyReLU(0.2, inplace=True),
            
        )

        self.decoder = nn.Sequential(
            
            nn.ConvTranspose2d(841, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(1024, 512, 3, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256,nc, 4, 2, 1, bias=False),
            
            nn.Sigmoid()
            
        )
        self.fc1 = nn.Linear(1024, 1024)
        self.fc21 = nn.Linear(1024, 10)
        self.fc22 = nn.Linear(1024, 10)

        self.fc3 = nn.Linear(10, 1024)
        self.fc4 = nn.Linear(1024, 1024)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

        # Dir prior
        self.prior_mean, self.prior_var = map(nn.Parameter, prior(10, 0.3))
        self.prior_logvar = nn.Parameter(self.prior_var.log())
        self.prior_mean.requires_grad = False
        self.prior_var.requires_grad = False
        self.prior_logvar.requires_grad = False


    def encode(self, x):
        conv = self.encoder(x);
        h1 = self.fc1(conv.view(-1, 1024))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, gauss_z):
        dir_z = F.softmax(gauss_z,dim=1) 
        h3 = self.relu(self.fc3(dir_z))
        deconv_input = self.fc4(h3)
        #deconv_input = deconv_input.view(-1,1024,1,1)
        deconv_input = deconv_input.view(deconv_input.size(0), 1024,1,1)
        return self.decoder(deconv_input)

    def sampling(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std


    def forward(self, x):
        mu, logvar = self.encode(x)
        gauss_z = self.sampling(mu, logvar)        
        dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
        return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

    # Reconstruction + KL divergence losses s
    def loss_function(self, recon_x, x, mu, logvar, K):
        beta = 1.0
        BCE = F.binary_cross_entropy(recon_x.view(-1, 2048), x.view(-1, 2048), reduction='sum')        
        prior_mean = self.prior_mean.expand_as(mu)
        prior_var = self.prior_var.expand_as(logvar)
        prior_logvar = self.prior_logvar.expand_as(logvar)
        var_division = logvar.exp() / prior_var 
        diff = mu - prior_mean 
        diff_term = diff *diff / prior_var 
        logvar_division = prior_logvar - logvar 
        # KL
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - K)        
        return BCE + KLD


model = RetinalVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        data = data.unsqueeze(0)
        optimizer.zero_grad()
        recon_batch, mu, logvar, gauss_z, dir_z = model(data)
        
        loss = model.loss_function(recon_batch, data, mu, logvar, 10)
        loss = loss.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 10 == 0:
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data = data.unsqueeze(0)
            recon_batch, mu, logvar, gauss_z, dir_z = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar, 1)
            test_loss += loss.mean()
            test_loss.item()
            if i == 0:
                n = min(data.size(0), 18)
                test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    # Train
    for epoch in enumerate(dataset):
        train(epoch)
        test(epoch)
        ```
```---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-643ea9d0266e> in <module>
    206     # Train
    207     for epoch in enumerate(dataset):
--> 208         train(epoch)
    209         test(epoch)
    210 

<ipython-input-2-643ea9d0266e> in train(epoch)
    170         data = data.unsqueeze(0)
    171         optimizer.zero_grad()
--> 172         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
    173 
    174         loss = model.loss_function(recon_batch, data, mu, logvar, 10)

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-2-643ea9d0266e> in forward(self, x)
    142         gauss_z = self.sampling(mu, logvar)
    143         dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
--> 144         return self.decode(gauss_z), mu, logvar, gauss_z, dir_z
    145 
    146     # Reconstruction + KL divergence losses s

<ipython-input-2-643ea9d0266e> in decode(self, gauss_z)
    130         deconv_input = deconv_input.view(-1,1024,1,1)
    131         #deconv_input = deconv_input.view(x.size(0), -1)
--> 132         return self.decoder(deconv_input)
    133 
    134     def sampling(self, mu, logvar):

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input, output_size)
    921             input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
    922 
--> 923         return F.conv_transpose2d(
    924             input, self.weight, self.bias, self.stride, self.padding,
    925             output_padding, self.groups, self.dilation)

RuntimeError: Given transposed=1, weight of size [841, 1024, 4, 4], expected input[841, 1024, 1, 1] to have 841 channels, but got 1024 channels instead

[/quote]

You have not chanced the code and are still using the same code from the first post as the error message indicates:

<ipython-input-2-643ea9d0266e> in decode(self, gauss_z)
    130         deconv_input = deconv_input.view(-1,1024,1,1)
    131         #deconv_input = deconv_input.view(x.size(0), -1)
--> 132         return self.decoder(deconv_input)

sorry I think I made a wrong post