Hello everybody, I am trying to build a neural network to regress the spin of a table tennis ball. I am implementing a ConvGRU network because I am working with ball trajectories that are organized into time bins. The input data is transformed into a voxel grid, which I feed to the network as a B x T x C x H x W tensor.
The network is implemented like this:
import torch
from torch import nn
import numpy as np
from torch.nn import init


class ConvGruCell(nn.Module):
    def __init__(self, input_channels: int, hidden_channels: int, kernel_size: int, padding: int, bias: bool):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.bias = bias
        # define GRU gates as convolutions to reduce the number of parameters
        self.reset_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
        self.update_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
        self.out_gate_conv = nn.Conv2d(self.input_channels + self.hidden_channels, self.hidden_channels, kernel_size=kernel_size, padding=self.padding, bias=self.bias)
        self._init_weights()
        self.c = 0  # debug counter: how often the hidden state gets re-initialized

    def forward(self, x, prev_state=None):
        # stack the previous hidden state on the current input
        batch_size, _, h, w = x.shape
        if prev_state is None:
            self.c += 1
            print(self.c)
            prev_state = torch.zeros(batch_size, self.hidden_channels, h, w, device=x.device)
        stacked_inputs = torch.cat([x, prev_state], dim=1)
        # apply reset, update, and output gates
        reset_gate = torch.sigmoid(self.reset_gate_conv(stacked_inputs))    # sigmoid activation for reset gate
        update_gate = torch.sigmoid(self.update_gate_conv(stacked_inputs))  # sigmoid activation for update gate
        out_gate = torch.tanh(self.out_gate_conv(torch.cat([x, reset_gate * prev_state], dim=1)))  # candidate state with tanh
        # calculate the new state
        new_state = (1 - update_gate) * prev_state + update_gate * out_gate
        return new_state

    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for conv_layer in [self.reset_gate_conv, self.update_gate_conv, self.out_gate_conv]:
            nn.init.xavier_uniform_(conv_layer.weight)
            if conv_layer.bias is not None:
                nn.init.zeros_(conv_layer.bias)
class ConvGru(nn.Module):
    def __init__(self, num_layers: int, input_channels: int, kernel_size: int, padding: int, bias: bool):
        super().__init__()
        self.num_layers = num_layers
        self.input_channels = input_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.bias = bias
        self.cells = nn.ModuleList(
            [ConvGruCell(input_channels=self.input_channels, hidden_channels=self.input_channels, kernel_size=self.kernel_size, padding=self.padding, bias=self.bias) for _ in range(self.num_layers)]
        )
        # initial hidden states (batch size and spatial size hard-coded to 32 and 110 x 110)
        self.prev_state = [
            torch.zeros(32, self.input_channels, 110, 110, dtype=torch.float32)
            for _ in range(self.num_layers)
        ]
        self.init = False
        self.c = 0  # debug counter

    def forward(self, x):
        time_bins, batch_size, c, h, w = x.shape
        # move the initial hidden states to the input device once
        if not self.init:
            self.init = True
            self.prev_state = [h0.to(x.device) for h0 in self.prev_state]
            self.c += 1
            print(self.c)
        hidden_states = [h0 for h0 in self.prev_state]
        current_input = x
        for layer in range(self.num_layers):
            cell = self.cells[layer]
            h_t = hidden_states[layer]
            layer_output = []
            for t in range(time_bins):
                x_t = current_input[t]
                h_t = cell(x_t, h_t)
                layer_output.append(h_t)
            layer_output = torch.stack(layer_output)
            hidden_states[layer] = h_t
            current_input = layer_output
        output = current_input[-1]
        # hidden_states = torch.stack(hidden_states, dim=1)
        return output  # , hidden_states
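For reference, this is a minimal shape check of the recurrent part in isolation; the batch size, channel count, and 110 x 110 spatial size are taken from the hard-coded hidden state above, and the snippet is not part of my actual pipeline:

# Minimal shape check for ConvGru in isolation (sizes match the hard-coded
# hidden state above: batch 32, 110 x 110 pixels; not part of my pipeline).
import torch

convgru = ConvGru(num_layers=1, input_channels=2, kernel_size=3, padding=1, bias=True)
dummy = torch.randn(5, 32, 2, 110, 110)   # T x B x C x H x W with 5 time bins
out = convgru(dummy)
print(out.shape)                          # expected: torch.Size([32, 2, 110, 110])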
import torch
from torch import nn
from typing import Tuple
import math
from src.models.components.subcomponents.conv_gru_scheerlink import ConvGru
from src.models.components.subcomponents.conv_gru import ResidualBlock


class FireNet(nn.Module):
    def __init__(self,
                 input_size: Tuple[int, int],
                 input_channels: int,
                 output_channels: int,
                 amount_of_output_neurons: int,
                 act_fun: nn.Module,
                 recurrent_kernel_size: int,
                 recurrent_padding: int,
                 dropout,
                 max_size: int):
        super().__init__()
        self.input_size = input_size
        self.input_channels = input_channels
        self.amount_of_output_neurons = amount_of_output_neurons
        self.max_size = max_size
        self.output_channels = output_channels
        self.recurrent_kernel_size = recurrent_kernel_size
        self.recurrent_padding = recurrent_padding
        self.act_fun = act_fun
        self.dropout = dropout
        self.conv1 = nn.Conv2d(self.input_channels, out_channels=self.output_channels, kernel_size=3, padding=1)
        self.convgru1 = ConvGru(input_channels=self.output_channels,
                                kernel_size=self.recurrent_kernel_size,
                                padding=self.recurrent_padding,
                                num_layers=1,
                                bias=True)
        self.tail = nn.Sequential(
            nn.Flatten(),
            # nn.BatchNorm1d(self.conv_output_channels * self.input_size[0] * self.input_size[1]),
            self.dropout,
            nn.Linear(in_features=self.output_channels * 110 * 110,
                      out_features=self.amount_of_output_neurons)
        )

    def svd_postprocess(self, out_matrix: torch.Tensor) -> torch.Tensor:
        """Postprocess the output matrix to the closest rotation matrix in SO(3).

        Args:
            out_matrix (torch.Tensor): output matrix of the neural network

        Returns:
            torch.Tensor: closest rotation matrix in SO(3)
        """
        U, _, Vh = torch.linalg.svd(out_matrix, full_matrices=False)
        # compute the determinant of U @ V^T so that det(R) = +1
        det = torch.linalg.det(U @ Vh)
        ones = torch.ones_like(det)
        diagonal_elements = torch.stack([ones, ones, det], dim=-1)
        S = torch.diag_embed(diagonal_elements)
        return U @ S @ Vh

    def forward(self, x: torch.Tensor):
        # B x T x C x H x W
        x = x.permute(1, 0, 2, 3, 4).float()
        # T x B x C x H x W
        # out = [self.conv1(x_i) for x_i in t] for x in
        out = self.convgru1(x)
        out = torch.relu(out)
        # out = self.convgru2(out)
        out = self.tail(out)
        # transform the 9-d output vector into a 3x3 output matrix M
        out = out.view(-1, 3, 3)  # B x 3 x 3
        # compute the closest rotation matrix R in SO(3)
        out = self.svd_postprocess(out)
        return out
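As a standalone check of the SVD projection (the same math as svd_postprocess, written out with dummy inputs; not part of the model code), the returned matrices should satisfy R^T R ≈ I and det(R) ≈ +1:

# Standalone check of the SVD projection used in svd_postprocess
# (dummy inputs; same math as the method above).
import torch

M = torch.randn(4, 3, 3)                              # arbitrary 3x3 network outputs
U, _, Vh = torch.linalg.svd(M, full_matrices=False)
det = torch.linalg.det(U @ Vh)
S = torch.diag_embed(torch.stack([torch.ones_like(det), torch.ones_like(det), det], dim=-1))
R = U @ S @ Vh
print(torch.allclose(R.transpose(-1, -2) @ R, torch.eye(3).expand(4, -1, -1), atol=1e-4))  # True
print(torch.linalg.det(R))                            # all entries close to +1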
The validation and training loss show no real improvement over training. My training setup:
Batch Size: 32
LR: 0.0001
L2 Penalty: 0.4
Dropout: 0.5
Optimizer: Adam
Num Layers: 2
Parameters: 108K
Data: 2000 Samples
Loss Function: L1 Loss
The output is a rotation matrix in SO(3), and the labels are also rotation matrices.
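Concretely, the loss is the element-wise L1 distance between the predicted and the ground-truth rotation matrix (the names and dummy tensors below are placeholders, not my exact training loop):

# L1 loss between predicted and ground-truth rotation matrices
# (dummy tensors as placeholders; not my exact training loop).
import torch
from torch import nn

criterion = nn.L1Loss()
pred_R = torch.eye(3).expand(8, 3, 3)    # stands in for the network output, B x 3 x 3
label_R = torch.eye(3).expand(8, 3, 3)   # ground-truth rotation matrices, B x 3 x 3
loss = criterion(pred_R, label_R)        # mean absolute error over all 9 entries
print(loss)                              # tensor(0.) for identical matrices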
I know the data is very limited, but I cannot collect more, so I apply data augmentation (time inversion and horizontal flipping).
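Roughly, the two augmentations act on the voxel grid like this (dimension order B x T x C x H x W assumed; the matching adjustment of the spin label is handled separately and not shown):

# Sketch of the two augmentations on a B x T x C x H x W voxel grid
# (assumed dimension order; the corresponding label transformation is omitted).
import torch

voxels = torch.randn(8, 5, 2, 110, 110)        # dummy batch: B x T x C x H x W
time_inverted = torch.flip(voxels, dims=[1])   # reverse the order of the time bins
h_flipped = torch.flip(voxels, dims=[4])       # mirror the width axis (horizontal flip)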
I have tried several learning rates in the interval [0.0001, 0.01], dropout rates in [0.2, 0.5], different depths {1, 2, 5, 20, 40}, and variants with and without pooling, but the network does not improve at all, no matter how I adapt it.
Thanks a lot for your help!