Multi-Input Network PyTorch-Lightning (tabular and images)

Good morning,
I am trying to implement a model in Pytorch lightning, as in here, capable of predicting the output of a system that simultaneously processes data from tables and images.
Specifically, the tabular data is organized into N rows (the number is very large and I think irrelevant to the question) and 10 columns, each representing a variable in a system. On the other hand, the images are N vectors (1, 3200) (which I consider to be images although 1D because they may be 2D in the future), each being associated with a row in the table. The inputs to my survey are therefore an image (1, 3200) and a table row (1, 10).
The model I wrote, continues to give constant output despite changing inputs. What am I doing wrong?
Thanks in advance

class LitClassifier(pl.LightningModule):
    def __init__(
        self, DB, lr: float = 1e-3, num_workers: int = 8, batch_size: int = 32):
        super().__init__()
        self.lr = lr
        self.num_workers = num_workers
        self.batch_size = batch_size
        
        self.DB = DB
        
        self.conv1 = conv_block(3, 32)
        self.conv2 = conv_block(32, 64)
        self.conv3 = conv_block(64, 64)

        self.ln1 = nn.Linear(25344, 22)
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm1d(22)
        self.dropout = nn.Dropout2d(0.5)
        self.ln2 = nn.Linear(22, 4) 
        
        self.ln4 = nn.Linear(10, 64)
        self.ln5 = nn.Linear(64, 4)
        
        self.ln9 = nn.Linear(8, 1) 
      
        self.train_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []
        
    def forward(self, img, tab):
       
        img = self.conv1(img)
        img = self.conv2(img)
        img = self.conv3(img)
        img = img.reshape(img.shape[0], -1)
        
        img = self.ln1(img)
        img = self.relu(img)
        img = self.batchnorm(img)
        img = self.dropout(img)
        img = self.ln2(img)
        img = self.relu(img)

        tab = self.ln4(tab)
        tab = self.relu(tab)
        tab = self.ln5(tab)
        tab = self.relu(tab)
        
        x = torch.cat((img, tab), dim=1)
        x = self.relu(x)

        return self.ln9(x)

    def training_step(self, batch, batch_idx):
        image, tabular, y = batch

        criterion = torch.nn.L1Loss()
        y_pred = torch.flatten(self(image, tabular))
        y_pred = y_pred.double()
        y = y.double()
        
        loss = criterion(y_pred, y)
        
        self.train_step_outputs.append(loss)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        image, tabular, y = batch

        criterion = torch.nn.L1Loss()
        y_pred = torch.flatten(self(image, tabular))
        y_pred = y_pred.double()
        y = y.double()
        
        val_loss = criterion(y_pred, y)
        
        self.validation_step_outputs.append(val_loss)
        
        return val_loss
    
    def test_step(self, batch, batch_idx):
        image, tabular, y = batch

        criterion = torch.nn.L1Loss()
        y_pred = torch.flatten(self(image, tabular))
        y_pred = y_pred.double()
        y = y.double()
        
        test_loss = criterion(y_pred, y)
        
        self.test_step_outputs.append(test_loss)
        
        return test_loss
    
    def setup(self, stage):

        bending_data = self.DB

        train_size = int(round(len(bending_data) * 0.8))
        val_size = int((len(bending_data) - train_size) // 2)
        test_size = len(bending_data) - train_size - val_size
       
        self.train_set, self.val_set, self.test_set = random_split(bending_data, (train_size, val_size, test_size))
        
    def on_validation_epoch_end(self):
       epoch_average = torch.stack(self.validation_step_outputs).mean()
       self.log("val_loss", epoch_average)
       print(f"val_loss: {epoch_average}")
       print(f"current epoch: {self.current_epoch}")
       self.validation_step_outputs.clear()  # free memory
       
    
    def on_test_epoch_end(self):
        avg_loss = torch.stack(self.test_step_outputs).mean()
        self.log("test_loss", avg_loss)
        print(f"\ntest_loss: {avg_loss}")
        print(f"current epoch: {self.current_epoch}")
        self.test_step_outputs.clear()
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=(self.lr))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)

Could you check if y_pred.double() does break backward pass?
In older versions, casting was not allowed.

Also you could check if gradients are not none when you backwards to be sure there is no issue, you can achieve this by using some lightning’s callbacks or disabling automatic optimization and running backwards manually.