Predictions and targets are expected to have the same shape, but got torch.Size([2, 3, 1, 1, 1]) and torch.Size([2, 3, 127, 127, 127])

I am getting an error.

RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([2, 3, 1, 1, 1]) and torch.Size([2, 3, 127, 127, 127]).

Below is my code for convolution Autoencorder for 3D images

# Define the architecture of the 3D convolutional autoencoder
class Convo3DAE(nn.Module):
    def __init__(self):
        super(Convo3DAE, self).__init__()
        # Encoder
        self.encoder= nn.Sequential()
        self.encoder.add_module('C1', nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3,stride=2 ))
        self.encoder.add_module('Batch Norma 3d', nn.BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
        self.encoder.add_module('relu1', nn.ReLU(True))
        self.encoder.add_module('C2', nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3,stride=2 ))
        self.encoder.add_module('relu2', nn.ReLU(True))
        self.encoder.add_module('C3',nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3,stride=2 ))
        self.encoder.add_module('relu3', nn.ReLU(True))
        self.encoder.add_module('C4',nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3,stride=2 ))


        # Decoder
        self.decoder= nn.Sequential()
        self.decoder.add_module('tC1',nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=3,stride=2 ))
        self.decoder.add_module('Batch Norma 3d', nn.BatchNorm3d(num_features = 64))
        self.decoder.add_module('relu1', nn.ReLU(True))
        self.decoder.add_module('tC2', nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3,stride=2 ))
        self.decoder.add_module('relu2', nn.ReLU(True))
        self.decoder.add_module('tC3', nn.Conv3d(in_channels=32, out_channels=16, kernel_size=3,stride=2 ))
        self.decoder.add_module('relu3', nn.ReLU(True))
        self.decoder.add_module('C2', nn.Conv3d(in_channels=16, out_channels=3, kernel_size=3,stride=2 ))




    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

This is the Data Loader



def load_img(img_dir, img_list):
    images=[]
    for i, image_name in enumerate(img_list):
        if (image_name.split('.')[1] == 'npy'):

            image = np.load(img_dir+image_name)

            images.append(image)
    images = np.array(images)

    return(images)
def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    # setting Length to 100
    L = len(img_list)

    #keras needs the generator infinite, so we will use while true
    while True:

        batch_start = 0
        batch_end = batch_size

        while batch_start < L:
            limit = min(batch_end, L)

            X = load_img(img_dir, img_list[batch_start:limit])
            Y = load_img(mask_dir, mask_list[batch_start:limit])

            yield (X,Y) #a tuple with two numpy arrays with batch_size samples

            batch_start += batch_size
            batch_end += batch_size
    if stop_flag:  # Stop generator after completing one pass through the dataset
        break

** Training and Validation function


def train_batch(data1, model, criterion, optimizer):
    model.train()
    data = data1
    optimizer.zero_grad()
    output = model(data.float())
    print(data.shape)
    print(output.shape)
    
    loss = criterion(output, data.float())
    loss.backward()
    optimizer.step()
    return loss

with torch.inference_mode():
  def eval_batch(data, model, criterion):
      model.eval()

      loss = criterion(output, data)
      output = model(data.float())
      loss = criterion(output, data.float())
      return loss.item()

This is the Training Loop

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

n_epochs = 100

training_loss, test_loss = [], []

for epoch in range(n_epochs):
    training_losses, test_losses = [], []
    print(epoch)
    i=0

    for data, _ in train_img_datagen:

        if type(data) is np.ndarray:
            data = torch.from_numpy(data)
        
        x = data.permute(0, 4, 1, 2, 3).contiguous()
        x = nn.functional.interpolate(x, size=(127, 127, 127), mode='trilinear', align_corners=False)

        print(x.shape)
        i=i+1
        trng_batch_loss = train_batch(x, model, criterion, optimizer)
        training_losses.append(trng_batch_loss.item())
    training_per_epoch_loss = np.array(training_losses).mean()

    for data, _ in val_img_datagen:
#        x = data.permute(0, 4, 1, 2, 3).contiguous()
#        x = nn.functional.interpolate(x, size=(127, 127, 127), mode='trilinear', align_corners=False)
        tst_batch_loss = eval_batch(x, model, criterion)
        test_losses.append(tst_batch_loss.item())
    test_per_epoch_loss = np.array(test_losses).mean()

    training_loss.append(training_per_epoch_loss)
    test_loss.append(test_per_epoch_loss)
    
    if (epoch+1) % 10==0:
        print(f'Epoch: {epoch+1}/{n_epochs}\t| Training loss: {training_per_epoch_loss:.4f} |   ', end='')
        print(f'Test loss: {test_per_epoch_loss:.4f}')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
           }, f'checkpoint{epoch}.pt')

Below is the error

torch.Size([2, 3, 127, 127, 127])
torch.Size([2, 3, 127, 127, 127])
torch.Size([2, 3, 1, 1, 1])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-63-dcc34f8a85e0> in <cell line: 9>()
     22         print(x.shape)
     23         i=i+1
---> 24         trng_batch_loss = train_batch(x, model, criterion, optimizer)
     25         training_losses.append(trng_batch_loss.item())
     26     training_per_epoch_loss = np.array(training_losses).mean()

9 frames
<ipython-input-61-c58836ad8f74> in train_batch(data1, model, criterion, optimizer)
      6     print(output.shape)
      7     print(data.shape)
----> 8     loss = criterion(output, data.float())
      9     loss.backward()
     10     optimizer.step()

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py in forward(self, *args, **kwargs)
    301             self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    302         else:
--> 303             self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    304 
    305         return self._forward_cache

/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py in _forward_reduce_state_update(self, *args, **kwargs)
    370 
    371         # calculate batch state and compute batch value
--> 372         self.update(*args, **kwargs)
    373         batch_val = self.compute()
    374 

/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
    473                             " device corresponds to the device of the input."
    474                         ) from err
--> 475                     raise err
    476 
    477             if self.compute_on_cpu:

/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
    463             with torch.set_grad_enabled(self._enable_grad):
    464                 try:
--> 465                     update(*args, **kwargs)
    466                 except RuntimeError as err:
    467                     if "Expected all tensors to be on" in str(err):

/usr/local/lib/python3.10/dist-packages/torchmetrics/regression/mse.py in update(self, preds, target)
     99     def update(self, preds: Tensor, target: Tensor) -> None:
    100         """Update state with predictions and targets."""
--> 101         sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs=self.num_outputs)
    102 
    103         self.sum_squared_error += sum_squared_error

/usr/local/lib/python3.10/dist-packages/torchmetrics/functional/regression/mse.py in _mean_squared_error_update(preds, target, num_outputs)
     31 
     32     """
---> 33     _check_same_shape(preds, target)
     34     if num_outputs == 1:
     35         preds = preds.view(-1)

/usr/local/lib/python3.10/dist-packages/torchmetrics/utilities/checks.py in _check_same_shape(preds, target)
     40     """Check that predictions and target have the same shape, else raise error."""
     41     if preds.shape != target.shape:
---> 42         raise RuntimeError(
     43             f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}."
     44         )

RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([2, 3, 1, 1, 1]) and torch.Size([2, 3, 127, 127, 127]).

Can you tell me why my predition is of torch.Size([2, 3, 1, 1, 1]) insted of torch.Size([2, 3, 127, 127, 127]). I have provided model and all necessary code.

Thank you

Your model reduces the spatial size of the input as it’s using conv layers with a stride of 2.
Here is a simple example:

conv = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=2)
x = torch.randn(1, 1, 24, 24, 24)
out = conv(x)
print(out.shape)
# torch.Size([1, 1, 11, 11, 11])