Same UNet architecture definition but very different results

Hi,
I have written a UNet model for image segmentation, but it does not seem to learn anything.
I then looked at another UNet implementation, which does work, and changed my model to match its architecture. Still, my model does not learn: two UNet implementations with the same architecture, yet very different outcomes (everything else being equal). Why does it not work with my UNet definition? And why does training my model take about three times as long as training the other one?
Below is the full code with a toy image; a visualization of the predicted masks is generated at the end.

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import OrderedDict

img = np.array([[ 25,  28,  20,  19,  94,  97, 104,  82, 116, 106, 109, 106, 105,  96,  86,  88,  83,  99,  91, 101, 109,  93,  86,  94,  89,  97, 110,  99,  96, 101,  93,  88],
                [ 27,  23,  27, 111, 116, 107, 108, 107, 104, 145, 114, 112,  94,  93, 106,  91,  98, 114, 108, 107, 106,  96,  98,  99,  95,  99, 100, 107,  98,  95,  91,  99],
                [ 24,  18, 105, 111, 105, 111, 103, 107, 114, 110, 129, 111, 116, 107, 108, 104, 118, 118, 121, 109, 107, 105, 101, 104,  99,  99,  84,  90,  94, 103,  93, 105],  
                [ 17,  78, 101,  99, 109, 114, 106, 104, 112,  93,  99,  95,  96,  93, 109, 106,  96, 106, 103, 113, 100,  95,  89,  86,  93,  93,  95,  93, 103,  94,  89,  98],
                [ 54, 101, 100, 100, 106, 105,  99,  99, 103,  99, 101,  84,  91,  91, 101, 104,  97, 111, 101, 123, 109,  90,  99,  91,  99,  90,  93,  97,  93,  93,  93, 114],
                [105,  96,  99,  98,  99,  94,  95,  96,  95,  88,  96, 104,  95, 119, 111, 100,  95,  98,  97, 103, 101,  91,  92,  94,  93,  93,  93,  95,  94,  93, 101, 104],
                [101, 100, 105,  96,  96,  91, 106,  93,  96,  90,  96,  92, 124, 112, 102,  96,  93,  96,  93, 101, 100,  94,  91,  87,  96,  96,  95,  88,  96, 101, 104, 109],
                [109, 101,  94,  92,  95, 104,  96, 101,  94,  98,  95,  97,  95, 108,  96,  99,  96,  99, 105, 109,  97,  96,  96,  93,  96,  95,  89,  96, 102, 110, 109, 103],
                [ 94,  98,  99,  92, 101,  99,  97, 109, 114, 105, 105,  93,  96, 101,  95,  99, 104,  96, 104,  94, 111, 105,  96,  92,  99,  89,  96, 104, 108, 107, 114, 122],
                [105,  96,  95,  97, 106, 111, 106, 120, 109,  97,  95,  99, 101, 106, 112,  95,  98, 108, 124, 105, 112,  96,  99,  94, 104,  95,  95, 111,  99, 114, 114, 105],
                [ 93,  96,  93, 100, 109, 104, 102, 133,  86, 105, 132,  82,  95, 115, 104, 108, 102, 101, 115, 159, 103, 111, 107,  98, 109, 101,  98, 125,  95, 108, 105, 117],
                [110, 100, 105, 110, 104,  96, 114, 117, 116, 111, 108,  96, 109,  99,  99, 109, 130,  95,  99,  84,  91,  93, 111, 106, 115,  99, 110, 106, 103, 102,  99, 116],
                [106, 107, 111, 109, 106, 116, 108, 119, 113,  94,  91,  91, 111,  94, 104, 110,  99,  89,  90,  91,  98, 108, 106, 109, 107, 115, 112, 119, 127, 105, 106, 106],
                [107, 113,  99, 114, 115, 114, 110, 117,  96,  97,  98, 104,  96,  95, 104, 113,  91,  96,  88,  88,  99,  91, 165, 136, 114, 117, 109, 114, 122, 115, 109, 100],
                [101, 112,  96,  96, 114, 107, 111, 142,  97, 119, 103,  97,  92,  94,  95, 105,  95,  93,  96,  87, 102,  91,  93, 103, 101, 109, 132, 114, 111, 109, 111,  88],
                [ 96, 105, 110, 114,  93, 101, 111, 124, 106, 104, 104,  92, 102,  91,  96,  96,  81,  93, 107, 104,  95,  88,  99,  98,  98, 116, 112, 106, 111, 111, 120,  99],
                [102,  91, 104, 109, 103,  94, 111, 111, 114,  88,  95,  93,  97,  95, 111, 137, 106,  99, 123, 111,  95, 100,  87,  97, 102, 104, 120, 109, 103, 117,  98,  90],
                [149, 110, 113,  99,  99, 107, 101, 115, 109, 120, 106, 109, 103, 101, 134, 147, 124, 127,  89,  97, 107,  93,  98, 101,  82,  99, 129, 109, 111, 109,  93, 135],
                [ 25,  19,  35,  61,  84,  95,  97,  91, 106, 114, 127, 123, 109, 101, 141, 136, 138,  96,  96,  93,  93,  96,  97,  95, 100,  89, 110, 111, 111,  95,  86,  87],
                [ 32,  31,  27,  29,  23,  28,  26,  24,  22,  51,  86, 115, 106, 111, 120, 199, 175, 109, 150,  99,  95,  96,  96,  96,  91,  96,  96,  96, 103,  98, 106, 108],
                [ 21,  35,  35,  36,  32,  33,  37,  37,  26,  23,  23,  24,  32,  76,  35,  47,  44,  28,  28,  22,  20,  20,  22,  22,  20,  19,  19,  24,  24,  24,  35,  70],
                [ 27,  32,  28,  27,  28,  29,  36,  41,  39,  40,  27,  30,  23,  24,  31,  33,  32,  28,  24,  24,  25,  22,  26,  24,  23,  23,  26,  23,  24,  27,  25,  70],
                [ 23,  27,  24,  29,  27,  27,  26,  29,  31,  31,  29,  37,  35,  37,  29,  37,  32,  26,  29,  37,  31,  36,  32,  27,  27,  22,  24,  25,  24,  29,  24,  68],
                [ 24,  22,  24,  20,  24,  24,  24,  24,  23,  26,  28,  27,  26,  37,  34,  28,  27,  32,  26,  26,  27,  28,  26,  24,  22,  27,  25,  28,  25,  25,  30,  55],
                [ 25,  23,  24,  22,  27,  27,  27,  22,  26,  25,  29,  24,  24,  19,  24,  25,  29,  30,  24,  22,  25,  22,  24,  28,  28,  22,  24,  41,  69,  88,  77,  60],
                [ 23,  25,  26,  28,  24,  24,  28,  24,  29,  29,  27,  25,  24,  28,  24,  26,  25,  23,  25,  25,  23,  25,  21,  27,  26,  26,  27,  53,  52,  65,  96, 111],
                [ 26,  25,  26,  24,  27,  24,  24,  25,  26,  25,  25,  23,  24,  23,  24,  22,  25,  27,  24,  22,  26,  21,  23,  24,  25,  24,  26,  27,  30,  23,  73,  27],
                [ 28,  27,  29,  27,  25,  24,  31,  22,  25,  29,  24,  24,  27,  26,  25,  27,  26,  24,  28,  22,  22,  22,  23,  24,  23,  25,  25,  26,  29,  27,  76,  24],
                [ 24,  29,  23,  27,  28,  29,  29,  28,  25,  27,  25,  25,  22,  27,  27,  23,  27,  24,  24,  22,  24,  24,  24,  22,  26,  23,  24,  24,  26,  27,  37,  24],
                [ 26,  26,  30,  29,  29,  29,  30,  29,  31,  28,  28,  28,  22,  25,  29,  27,  22,  26,  24,  24,  26,  26,  24,  25,  24,  22,  24,  25,  27,  28,  27,  25],
                [ 26,  25,  29,  28,  26,  28,  29,  30,  26,  27,  26,  26,  24,  22,  27,  25,  28,  27,  28,  26,  24,  24,  27,  23,  24,  24,  24,  30,  27,  63,  23,  25],
                [ 27,  29,  27,  27,  27,  26,  27,  29,  30,  27,  29,  29,  25,  27,  27,  25,  26,  24,  26,  27,  28,  28,  25,  27,  22,  26,  22,  39,  24,  74,  25,  32]])

img = img / 255
img = np.moveaxis(np.repeat(img[..., np.newaxis], 3, -1), 2, 0)[np.newaxis,...] # copy the grayscale value 3 times
img = torch.Tensor(img)

mask = np.array([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

mask = mask[np.newaxis,...,np.newaxis]
mask = torch.Tensor(mask)
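
# A quick shape check confirms the layouts. The image is channels-first (N, C, H, W)
# as PyTorch expects, while the mask stays channels-last; this works here only because
# dice_loss (defined below) flattens both tensors, and since the extra dimensions are
# singletons, the pixel order still lines up.
print(img.shape)   # torch.Size([1, 3, 32, 32])
print(mask.shape)  # torch.Size([1, 32, 32, 1])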

class Encoder(nn.Module):
    
    def __init__(self, filters, depth, img_channels, kernel_size, padding):

        self.filters = filters
        self.depth = depth
        self.img_channels = img_channels
        self.kernel_size = kernel_size
        self.padding = padding
    
        super().__init__()
        
    def forward(self):
        """Build the contraction path.
        The number of blocks is controlled by the depth parameter (via the filters list).
        """
        
        blocks = nn.ModuleList()
        previous_size = self.img_channels
        
        for i, filter_size in enumerate(self.filters[:-1]):
            conv1 = nn.Conv2d(previous_size, filter_size, self.kernel_size, padding=self.padding, bias=False)
            conv2 = nn.Conv2d(filter_size, filter_size, self.kernel_size, padding=self.padding, bias=False)
            pool = nn.MaxPool2d(2) # 2 = kernel size and stride
            relu = nn.ReLU(inplace=True)
            batchnorm = nn.BatchNorm2d(filter_size)
            
            # I want to concatenate features from the last conv layer (prior to max pooling),
            # so in the first iteration the sub-block looks different (no pooling in front)
            if i == 0:
                sub_block = [conv1, batchnorm, relu, conv2, batchnorm, relu]
            else:
                sub_block = [pool, conv1, batchnorm, relu, conv2, batchnorm, relu]
                                
            sub_block = nn.Sequential(*sub_block)
            blocks.append(sub_block)
            
            previous_size = filter_size
            
        return blocks.to(device)
        
        
        
class Decoder(nn.Module):
    
    def __init__(self, filters, depth, img_channels, kernel_size, padding, deconv_kernel_size, deconv_stride):

        self.filters = filters
        self.depth = depth
        self.img_channels = img_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.deconv_kernel_size = deconv_kernel_size
        self.deconv_stride = deconv_stride
    
        super().__init__()
                
    def forward(self):
        """Build the expansion path (up-convolution and concatenation blocks)."""
        
        blocks = nn.ModuleList()
        previous_size = self.filters[-1]
        
        for filter_size in reversed(self.filters[:-1]):
        
            up = nn.ConvTranspose2d(previous_size, filter_size, kernel_size=self.deconv_kernel_size, stride=self.deconv_stride)
            
            conv1 = nn.Conv2d(previous_size, filter_size, self.kernel_size, padding=self.padding, bias=False)
            conv2 = nn.Conv2d(filter_size, filter_size, self.kernel_size, padding=self.padding, bias=False)
            relu = nn.ReLU(inplace=True)
            batchnorm = nn.BatchNorm2d(filter_size)

            conv_block = [conv1, batchnorm, relu, conv2, batchnorm, relu]
            modules = nn.ModuleList([up, nn.Sequential(*conv_block)])
            
            blocks.append(modules)
            
            previous_size = filter_size
        
        return blocks.to(device)
    
    
class UNet(nn.Module):
    
    def __init__(self, 
                 n_classes,
                 img_channels, 
                 start_filter, 
                 depth,
                 kernel_size = 3,
                 deconv_kernel_size = 2,  
                 deconv_stride = 2,
                 padding = 1,
                 ): 
    
        super().__init__()
        
        self.n_classes = n_classes
        self.start_filter = start_filter
        self.depth = depth
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.filters = [start_filter*(2**i) for i in range(depth+1)]
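        # e.g. start_filter=32 and depth=4 yields [32, 64, 128, 256, 512]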
        
        self.encoder = Encoder(self.filters, depth, img_channels, kernel_size, padding)
        self.decoder = Decoder(self.filters, depth, img_channels, kernel_size, padding, deconv_kernel_size, deconv_stride)
        
        self.output = nn.Conv2d(
            self.filters[0],
            self.n_classes,
            kernel_size=1
        )
    
    def _bottom_path(self, x):
        
        # set the bottom path
        pool = nn.MaxPool2d(2) # 2 = kernel size and stride
        conv1 = nn.Conv2d(self.filters[-2], self.filters[-1], self.kernel_size, padding=self.padding, bias=False)
        conv2 = nn.Conv2d(self.filters[-1], self.filters[-1], self.kernel_size, padding=self.padding, bias=False)
        relu = nn.ReLU(inplace=True)
        batchnorm = nn.BatchNorm2d(self.filters[-1])
        
        bottom_path_layers = nn.Sequential(pool, conv1, batchnorm, relu, conv2, batchnorm, relu).to(device)
        bottom_path_output = bottom_path_layers(x)
        
        return bottom_path_output
    
    
    def forward(self, x):
        
        contract_blocks = self.encoder()
        expand_blocks = self.decoder()
        
        x_contract = []
        
        # iterate over each contract block 
        for contract_block in contract_blocks:
            x_contracted = contract_block(x)
            x_contract.append(x_contracted)
            x = x_contracted
            
        x = self._bottom_path(x)
        
        for i, expand_block in enumerate(expand_blocks):
            # do ConvTranspose2d (i.e. up convolution)
            x_expanded = expand_block[0](x)
            
            # concatenate it with features from the contract path on the channel dimension (axis=1)
            #x_expanded = torch.cat((x_contract[self.depth - (i+1)], x_expanded), axis=1)
            x_expanded = torch.cat((x_expanded, x_contract[self.depth - (i+1)]), axis=1)
            
            # convolve on concatenated features
            x = expand_block[1](x_expanded)
            
        y_pred = self.output(x)
    
        return torch.sigmoid(y_pred)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(n_classes=1, img_channels=3, start_filter=32, depth=4)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters())
epochs = 200

def dice_loss(predicted_mask, true_mask, eps=1e-8):

    # flatten the predicted mask and the true mask
    batch_dim = predicted_mask.shape[0]
    predicted_mask = predicted_mask.view(batch_dim, -1)
    true_mask = true_mask.view(batch_dim, -1)
    
    intersection = (predicted_mask * true_mask).sum(1)
    dice = (2 * intersection) / (predicted_mask.sum(1) + true_mask.sum(1) + eps)
    dice = dice.sum() / batch_dim
    
    # the dice score is between 0 (worst prediction) and 1 (perfect prediction),
    # so the expression 1 - dice will be minimized
    return 1 - dice
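
# For intuition, a quick sanity check of dice_loss on hand-made tensors:
# a perfect prediction yields a loss of (almost) 0, a completely disjoint one a loss of 1.
t = torch.Tensor([[1., 1., 0., 0.]])
print(dice_loss(t, t).item())                                   # ~0.0, perfect overlap
print(dice_loss(t, torch.Tensor([[0., 0., 1., 1.]])).item())    # 1.0, no overlap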

img, mask = img.to(device), mask.to(device)

for epoch in tqdm(range(epochs)):
    
    train_epoch_loss = 0

    # zero the parameter gradients
    optimizer.zero_grad()
    
    # forward pass
    y_pred = model(img)        

    # backward pass
    loss = dice_loss(y_pred.to(device), mask.to(device))
    loss.backward()
    
    # optimize
    optimizer.step()

    train_epoch_loss += loss.item()

    #print("EPOCH LOSS:", train_epoch_loss)
    
    
model.eval()

# test loop, deactivate gradient calculation
with torch.no_grad():
   
    # predict (i.e. use only the forward pass of the model)
    y_pred = model(img)

    # generate the mask
    pred_mask_own_unet = (y_pred > 0.5).type(torch.uint8)[0,0,:,:].cpu().numpy()


####################################### COPIED UNET MODEL ####################################################
# from https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

model2 = UNet()
model2 = model2.to(device)

optimizer = torch.optim.Adam(model2.parameters())

for epoch in tqdm(range(epochs)):
    
    train_epoch_loss = 0

    # zero the parameter gradients
    optimizer.zero_grad()
    
    # forward pass
    y_pred = model2(img)        

    # backward pass
    loss = dice_loss(y_pred.to(device), mask.to(device))
    loss.backward()
    
    # optimize
    optimizer.step()

    train_epoch_loss += loss.item()

    #print("EPOCH LOSS:", train_epoch_loss)
    
    
model2.eval()

# test loop, deactivate gradient calculation
with torch.no_grad():
   
    # predict (i.e. use only the forward pass of the model)
    y_pred = model2(img)

    # generate the mask
    pred_mask_copied_unet = (y_pred > 0.5).type(torch.uint8)[0,0,:,:].cpu().numpy()


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10,10))
ax1.imshow((mask[0,:,:,0].cpu().numpy() * 255).astype(np.uint8))
ax2.imshow((pred_mask_own_unet * 255).astype(np.uint8))
ax3.imshow((pred_mask_copied_unet * 255).astype(np.uint8))
ax1.set_title("Label")
ax2.set_title("Predicted mask own UNet")
ax3.set_title("Predicted mask copied UNet")
ax1.axis("off")
ax2.axis("off")
ax3.axis("off")

OK, I realized what the mistake was: I created the layers in the forward methods of the Encoder and Decoder classes (and in _bottom_path) instead of in __init__. That means a fresh set of randomly initialized layers is built on every forward pass, so no learning is possible. This also explains the runtime difference (rebuilding every layer on each call is expensive) and a second consequence: layers created in forward are never registered as submodules, so model.parameters() handed the optimizer little more than the final 1x1 output convolution. After moving the layer construction to __init__, the model learns and trains roughly as fast as the copied one.
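
For anyone hitting the same problem, here is a minimal sketch of how the Encoder can be fixed (the Decoder and _bottom_path need the same treatment): all layers are created once in __init__, and forward only applies them. Note that the original block also reused one BatchNorm2d instance for both convolutions; the sketch uses two separate instances, matching the copied implementation.

class Encoder(nn.Module):

    def __init__(self, filters, depth, img_channels, kernel_size, padding):
        super().__init__()

        # build all blocks ONCE; storing them in an nn.ModuleList registers
        # their parameters with the module, so the optimizer can see them
        self.blocks = nn.ModuleList()
        previous_size = img_channels

        for i, filter_size in enumerate(filters[:-1]):
            # the first block has no pooling in front (see comment above)
            layers = [] if i == 0 else [nn.MaxPool2d(2)]
            layers += [
                nn.Conv2d(previous_size, filter_size, kernel_size, padding=padding, bias=False),
                nn.BatchNorm2d(filter_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(filter_size, filter_size, kernel_size, padding=padding, bias=False),
                nn.BatchNorm2d(filter_size),  # a second, separate BatchNorm2d
                nn.ReLU(inplace=True),
            ]
            self.blocks.append(nn.Sequential(*layers))
            previous_size = filter_size

    def forward(self, x):
        # only APPLY the pre-built blocks, collecting the skip connections
        skips = []
        for block in self.blocks:
            x = block(x)
            skips.append(x)
        return skips

UNet.forward then simply calls x_contract = self.encoder(x) instead of constructing the blocks first, and model.to(device) handles device placement, so the blocks.to(device) calls inside forward can go as well. A quick way to catch this class of bug is to compare parameter counts; with the layers built in forward, the broken model registers almost nothing:

print(sum(p.numel() for p in model.parameters()))   # broken model: only the 1x1 output conv (33 params)
print(sum(p.numel() for p in model2.parameters()))  # copied model: several million parameters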