Regressing a rotation matrix from event-based vision data with a ConvGRU architecture, but the model isn't able to learn

Hello everybody, I am trying to build a neural network to regress the spin of a table tennis ball. I am implementing a ConvGRU network because I am working with ball trajectories that are organized into time bins. The input data is transformed into a voxel grid, which looks like this:
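Roughly, the conversion works like this (a simplified sketch of a typical event-to-voxel-grid conversion, not my exact preprocessing code; the (x, y, t, polarity) event layout is just for illustration):

import torch

def events_to_voxel_grid(events: torch.Tensor, num_bins: int, height: int, width: int) -> torch.Tensor:
    """Accumulate events given as (x, y, t, polarity) rows into a (num_bins, H, W) voxel grid."""
    voxel = torch.zeros(num_bins, height, width)
    x, y, t, p = events[:, 0].long(), events[:, 1].long(), events[:, 2], events[:, 3]
    # normalize timestamps to [0, num_bins - 1] and assign each event to a time bin
    t_norm = (t - t.min()) / (t.max() - t.min() + 1e-9) * (num_bins - 1)
    bins = t_norm.long().clamp(0, num_bins - 1)
    # signed accumulation of polarities per pixel and time bin
    voxel.index_put_((bins, y, x), p.float(), accumulate=True)
    return voxel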

The network is implemented like this:

import torch
from torch import nn
import numpy as np

from torch.nn import init

class ConvGruCell(nn.Module):
    def __init__(self, input_channels:int, hidden_channels:int, kernel_size:int, padding:int, bias:bool):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.bias = bias
        
        # define GRU Gates as convolutions to reduce amount of parameters
        self.reset_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
        self.update_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
        self.out_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
    
        self._init_weights()

        # debug counter: counts how often a fresh hidden state is created
        self.c = 0

    def forward(self, x, prev_state=None):        
        # stack the previous output state on the current input
        batch_size , _, h, w = x.shape
        if prev_state is None:
            self.c += 1
            print(self.c)
            prev_state = torch.zeros(batch_size, self.hidden_channels, h, w, device=x.device)
            
        stacked_inputs = torch.cat([x, prev_state], dim=1)
        
        # Apply reset, update, and output gates
        reset_gate = torch.sigmoid(self.reset_gate_conv(stacked_inputs))  # Sigmoid activation for reset gate
        update_gate = torch.sigmoid(self.update_gate_conv(stacked_inputs))  # Sigmoid activation for update gate
        out_gate = torch.tanh(self.out_gate_conv(torch.cat([x, reset_gate * prev_state], dim=1)))  # Candidate state (tanh) from input and reset-gated previous state

        # Calculate the new state
        new_state = (1 - update_gate) * prev_state + update_gate * out_gate
        
        return new_state
    
    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for conv_layer in [self.reset_gate_conv, self.update_gate_conv, self.out_gate_conv]:
            nn.init.xavier_uniform_(conv_layer.weight)
            if conv_layer.bias is not None:
                nn.init.zeros_(conv_layer.bias)
    
class ConvGru(nn.Module):
    def __init__(self, num_layers:int, input_channels:int, kernel_size:int, padding:int, bias:bool):
        super().__init__()
        self.num_layers = num_layers
        self.input_channels = input_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.bias = bias
        self.cells = nn.ModuleList(
            [ConvGruCell(input_channels=self.input_channels, hidden_channels=self.input_channels , kernel_size=self.kernel_size, padding=self.padding, bias=self.bias) for _ in range(self.num_layers)]
        )
        # NOTE: the initial hidden states are hard-coded to batch size 32 and 110x110 inputs
        self.prev_state = [
            torch.zeros(32, self.input_channels, 110, 110, dtype=torch.float32)
            for _ in range(self.num_layers)
        ]

        self.init = False
        self.c = 0  # debug counter
    def forward(self, x):
        time_bins, batch_size, c, h, w = x.shape

        # move the initial hidden states to the input device once
        if not self.init:
            self.init = True
            self.prev_state = [state.to(x.device) for state in self.prev_state]
            self.c += 1
            print(self.c)

        hidden_states = [h0 for h0 in self.prev_state]

        current_input = x
        for layer in range(self.num_layers):
            cell = self.cells[layer]
            h_t = hidden_states[layer]

            layer_output = []
            for t in range(time_bins):
                x_t = current_input[t]
                h_t = cell(x_t, h_t)
                layer_output.append(h_t)

            layer_output = torch.stack(layer_output)
            hidden_states[layer] = h_t
            current_input = layer_output

        # output of the last layer at the last time bin
        output = current_input[-1]
        # hidden_states = torch.stack(hidden_states, dim=1)

        return output  # , hidden_states
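
For reference, this module expects its input as T x B x C x H x W and returns the hidden state of the last layer at the last time bin. A quick shape check (the channel count is arbitrary here; batch size 32 and 110x110 have to match the hard-coded hidden-state template):

convgru = ConvGru(num_layers=1, input_channels=16, kernel_size=3, padding=1, bias=True)
dummy = torch.randn(10, 32, 16, 110, 110)  # T x B x C x H x W
out = convgru(dummy)
print(out.shape)  # torch.Size([32, 16, 110, 110])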


import torch
from torch import nn
from typing import Tuple
import math
from src.models.components.subcomponents.conv_gru_scheerlink import ConvGru
from src.models.components.subcomponents.conv_gru import ResidualBlock

class FireNet(nn.Module):
    def __init__(self,
                 input_size: Tuple[int, int],
                 input_channels: int,
                 output_channels: int,
                 amount_of_output_neurons: int,
                 act_fun: nn.Module,
                 recurrent_kernel_size: int,
                 recurrent_padding: int,
                 dropout: nn.Module,
                 max_size: int):
        super().__init__()
        self.input_size = input_size
        self.input_channels = input_channels
        self.amount_of_output_neurons = amount_of_output_neurons

        self.max_size = max_size
        self.output_channels = output_channels

        self.recurrent_kernel_size = recurrent_kernel_size
        self.recurrent_padding = recurrent_padding

        self.act_fun = act_fun
        self.dropout = dropout

        self.conv1 = nn.Conv2d(self.input_channels, out_channels=self.output_channels, kernel_size=3, padding=1)

        self.convgru1 = ConvGru(input_channels=self.output_channels, 
                               kernel_size=self.recurrent_kernel_size, 
                               padding=self.recurrent_padding, 
                               num_layers=1,
                               bias=True)

        
        self.tail = nn.Sequential(
            nn.Flatten(),
            # nn.BatchNorm1d(self.conv_output_channels * self.input_size[0] * self.input_size[1]),
            self.dropout,
            # NOTE: the spatial size 110x110 is hard-coded here as well
            nn.Linear(in_features=self.output_channels * 110 * 110,
                      out_features=self.amount_of_output_neurons)
        )
    
    def svd_postprocess(self, out_matrix:torch.Tensor)-> torch.Tensor:
        """postprocess output matrix to closest rotation matrix in SO(3)

        Args:
            out_matrix (torch.Tensor): output matrix of neural network

        Returns:
            torch.Tensor: closest rotation matrix in SO(3)
        """
        U, _, Vh = torch.linalg.svd(out_matrix, full_matrices=False)
        
        # compute determinant of U @ V^T
        det = torch.linalg.det(U @ Vh)
        ones = torch.ones_like(det)
        diagonal_elements = torch.stack([ones, ones, det], dim=-1)

        S = torch.diag_embed(diagonal_elements)
        
        return U @ S @ Vh
            
    def forward(self, x: torch.Tensor):
        # B x T x C x H x W
        x = x.permute(1, 0, 2, 3, 4).float()
        # T x B x C x H x W
        # out = torch.stack([self.conv1(x_i) for x_i in x])

        out = self.convgru1(x)
        out = torch.relu(out)

        # out = self.convgru2(out)

        out = self.tail(out)

        # transform 9d output vector into a 3x3 matrix M
        out = out.view(-1, 3, 3)  # B x 3 x 3
        # project M onto the closest rotation matrix R in SO(3)
        out = self.svd_postprocess(out)

        return out
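
A quick smoke test of the full model, just to make the shapes explicit (the constructor values below are example values, not my actual config; they have to match the hard-coded 32 x 110 x 110 assumptions above):

model = FireNet(input_size=(110, 110), input_channels=16, output_channels=16,
                amount_of_output_neurons=9, act_fun=nn.ReLU(),
                recurrent_kernel_size=3, recurrent_padding=1,
                dropout=nn.Dropout(0.5), max_size=110)
x = torch.randn(32, 10, 16, 110, 110)  # B x T x C x H x W
R = model(x)                           # B x 3 x 3
# every output should be (numerically close to) a valid rotation matrix
print(torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand_as(R), atol=1e-4))
print(torch.linalg.det(R))             # should all be ~1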

The validation and training loss look like this:

Batch Size: 32
LR: 0.0001
L2 Penalty: 0.4
Dropout: 0.5
Optimizer: Adam
Num Layers: 2
Parameters: 108K
Data: 2000 Samples
Loss Function: L1 Loss
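
In code, this setup corresponds roughly to the following (assuming the L2 penalty is Adam's weight_decay; my actual training loop is not shown here):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.4)
criterion = nn.L1Loss()  # element-wise L1 between predicted and target 3x3 rotation matrices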

The output is a rotation matrix in SO(3), and the labels are also rotation matrices.

I know this is very limited data, but I can't collect more, so I apply data augmentation (time inversion and horizontal flipping).
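
On the voxel grid these augmentations look roughly like this (a simplified sketch; the corresponding adjustment of the rotation-matrix label is not shown):

def time_inversion(voxel: torch.Tensor) -> torch.Tensor:
    return torch.flip(voxel, dims=[0])   # reverse the time-bin axis (T x C x H x W)

def horizontal_flip(voxel: torch.Tensor) -> torch.Tensor:
    return torch.flip(voxel, dims=[-1])  # mirror along the width axis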

I've tried multiple learning rates in the interval [0.0001, 0.01], dropout rates in [0.2, 0.5], different depths {1, 2, 5, 20, 40}, and variants with/without pooling, but the network doesn't improve at all, no matter how I adapt it.

Thank you a lot for your help!