Loss quickly decreases to zero after the first epoch

Hi everyone! I’m currently working on a self-supervised learning model. Building and training the model works, but every time I start the training, my loss drops to zero (or almost zero) right after the first epoch. I don’t understand why. Before, I used the cross-entropy loss (nn.CrossEntropyLoss, i.e. without a separate LogSoftmax and NLLLoss) as the loss function, and my loss stayed constantly above 1 (which wasn’t good either). I know it’s a lot to ask, but does anybody have advice on how to tackle this problem? I would appreciate any input!
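For reference: nn.CrossEntropyLoss applied to raw logits computes the same value as nn.LogSoftmax(dim=1) followed by nn.NLLLoss, so switching between the two formulations should not by itself change the loss. A quick check with random logits (batch size and number of classes are arbitrary choices here):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(64, 9)            # [batch_size, nb_classes], raw scores
targets = torch.randint(0, 9, (64,))   # one class index per sample

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

# Both formulations give the same loss (up to floating-point error)
print(ce.item(), nll.item())
print(torch.allclose(ce, nll))
```

If the two setups behave very differently in training, the difference is therefore more likely to come from what the model outputs (for example, which dimension the softmax normalizes) than from the choice of loss function.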

Information about the training:

Batch_size = 64

Learning rate = 0.0001 (I already tried 0.001 and 0.1; apart from a huge loss in the first epoch, nothing really changes)

Size of my training set = 16303 segments (each segment has around 3000 ‘rows’ of information for three different sensor types)

Class imbalance? = Yes, there is a class imbalance, but to counter it I calculate class weights for each batch

Goal of the network (since it may look strange that I feed three different sensor types into the model) = for each sensor type, classify which data augmentation was applied to that type

Here are my model and my training process:

import torch
import torch.nn as nn
from torch.nn import BatchNorm1d, Conv1d, Dropout, Sequential


class BaseModel(nn.Module):

    def __init__(self,
                 dropout_p=0.2):
        super(BaseModel, self).__init__()

        # Kernel sizes
        self.K24 = 24
        self.K16 = 16
        self.K8 = 8
        self.K4 = 4
        self.K2 = 2

        # Output channels
        self.out_features32 = 32
        self.out_features64 = 64
        self.out_features96 = 96
        self.out_features128 = 128

        # CNN block for first sensor type
        self.conv_block_acc = Sequential(
            Conv1d(in_channels=3, out_channels=self.out_features32,
                   kernel_size=self.K24,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features32),

            Conv1d(in_channels=self.out_features32,
                   out_channels=self.out_features64,
                   kernel_size=self.K16,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features64),

            Conv1d(in_channels=self.out_features64,
                   out_channels=self.out_features96,
                   kernel_size=self.K8,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features96)
        )

        # CNN block for second sensor type
        self.conv_block_eog = Sequential(
            Conv1d(in_channels=2, out_channels=self.out_features32,
                   kernel_size=self.K24,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features32),

            Conv1d(in_channels=self.out_features32,
                   out_channels=self.out_features64,
                   kernel_size=self.K16,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features64),

            Conv1d(in_channels=self.out_features64,
                   out_channels=self.out_features96,
                   kernel_size=self.K8,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features96)
        )

        # CNN block for third sensor type
        self.conv_block_gyro = Sequential(
            Conv1d(in_channels=3, out_channels=self.out_features32,
                   kernel_size=self.K24,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features32),

            Conv1d(in_channels=self.out_features32,
                   out_channels=self.out_features64,
                   kernel_size=self.K16,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features64),

            Conv1d(in_channels=self.out_features64,
                   out_channels=self.out_features96,
                   kernel_size=self.K8,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features96)
        )

        # CNN block for the concatenated sensor types
        self.conv_block2 = Sequential(
            Conv1d(in_channels=self.out_features96,
                   out_channels=self.out_features128,
                   kernel_size=self.K4,
                   padding=1),
            nn.ReLU(),
            BatchNorm1d(self.out_features128)
        )

        # Max pooling layer (kernel_size=2, stride=2, i.e. it halves the sequence length)
        self.global_max_pool = nn.MaxPool1d(kernel_size=self.K2,
                                            stride=2)

        # Dropout layer
        self.dropout = Dropout(p=dropout_p)

        # Classifier for the first and third sensor types (they have the same number of classes)
        self.classifer_acc_gyro = Sequential(
            nn.LazyLinear(512),
            # nn.Linear(1534, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 9)
        )

        # Classifier for second sensor type
        self.classifer_eog = Sequential(
            nn.LazyLinear(512),
            # nn.Linear(1534, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 8)
        )

        self.softmax_layer = nn.LogSoftmax(dim=1)

    def forward(self, x_eog, x_acc, x_gyro):
        """
        x = All sensors
        x_eog = EOG
        x_acc = ACC
        x_gyro = GYRO
        """

        x_eog = torch.reshape(x_eog,
                              (x_eog.shape[0],
                               x_eog.shape[2],
                               x_eog.shape[1]
                               ))

        x_acc = torch.reshape(x_acc,
                              (x_acc.shape[0],
                               x_acc.shape[2],
                               x_acc.shape[1]
                               ))
        x_gyro = torch.reshape(x_gyro,
                               (x_gyro.shape[0],
                                x_gyro.shape[2],
                                x_gyro.shape[1]
                                ))

        # Feeding each sensor to its corresponding CNN block
        x_eog = self.conv_block_eog(x_eog)
        x_acc = self.conv_block_acc(x_acc)
        x_gyro = self.conv_block_gyro(x_gyro)

        # Concatenating all sensor types along the batch dimension
        x = torch.cat((x_eog,
                       x_acc,
                       x_gyro), dim=0)

        # Feeding the concatenated sensors to CNN block 2
        x = self.conv_block2(x)

        x = self.global_max_pool(x)

        x = self.dropout(x)

        # Splitting the sensors again along the batch dimension
        x_eog, x_acc, x_gyro = torch.split(x, (x.shape[0] // 3), dim=0)

        # Classifier first sensor type
        x_acc = self.classifer_acc_gyro(x_acc)
        # x_acc = self.softmax_layer(x_acc)

        # Classifier second sensor type
        x_eog = self.classifer_eog(x_eog)
        # x_eog = self.softmax_layer(x_eog)

        # Classifier third sensor type
        x_gyro = self.classifer_acc_gyro(x_gyro)
        # x_gyro = self.softmax_layer(x_gyro)

        # Apply LogSoftmax
        x_eog = self.softmax_layer(x_eog)
        x_acc = self.softmax_layer(x_acc)
        x_gyro = self.softmax_layer(x_gyro)

        return x_eog[:, -1], x_acc[:, -1], x_gyro[:, -1]
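A side note on the forward pass above: the classifiers receive tensors of shape [batch, 128, time], and nn.Linear / nn.LazyLinear act on the last dimension, so their outputs have shape [batch, 128, nb_classes]. nn.LogSoftmax(dim=1) then normalizes over the 128 channels rather than over the classes, and [:, -1] picks the last channel. The sketch below uses random tensors to show what that implies: the returned values are not a distribution over the classes, and if the last channel's logits dominate (the +50 offset is an arbitrary stand-in for that), every class gets a log-probability near zero and NLLLoss goes to zero whatever the target is. Whether this is what happens in the actual training run is an assumption, not something confirmed here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Mimic the classifier output: [batch_size, channels, nb_classes]
logits = torch.randn(4, 128, 9)
targets = torch.randint(0, 9, (4,))

log_probs = nn.LogSoftmax(dim=1)(logits)  # normalized over the 128 channels
last = log_probs[:, -1]                   # [4, 9], what forward() returns

# Summed over the 9 classes, these probabilities do not add up to 1
print(last.exp().sum(dim=1))

# If the last channel dominates, every entry of `last` is close to 0 ...
logits_dominant = logits.clone()
logits_dominant[:, -1, :] += 50.0
last_dominant = nn.LogSoftmax(dim=1)(logits_dominant)[:, -1]

# ... and the NLL loss is close to zero for any target
loss = nn.NLLLoss()(last_dominant, targets)
print(loss.item())
```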

import torch
import tqdm
from torch.nn import NLLLoss

for epoch in tqdm.tqdm(range(1, epochs + 1)):

    # --- TRAIN AND EVALUATE ON TRAINING SET --------------------------

    model.train()

    train_loss_eog, train_loss_acc, train_loss_gyro = 0.0, 0.0, 0.0

    num_train_correct_eog, num_train_correct_acc, num_train_correct_gyro = 0, 0, 0

    num_train_examples_eog, num_train_examples_acc, num_train_examples_gyro = 0, 0, 0

    for batch in train_dataloader:

        optimizer.zero_grad()

        # Splitting the data according to the sensor types
        inputs_eog, targets_eog = batch[0][0], batch[0][1]
        inputs_acc, targets_acc = batch[1][0], batch[1][1]
        inputs_gyro, targets_gyro = batch[2][0], batch[2][1]

        # Calculating the class weights
        class_weights_eog = calculate_class_weight(targets_eog, sensortype='EOG')
        class_weights_acc = calculate_class_weight(targets_acc, sensortype='ACC')
        class_weights_gyro = calculate_class_weight(targets_gyro, sensortype='GYRO')

        # Moving everything to the GPU
        class_weights_eog = class_weights_eog.to(device)
        class_weights_acc = class_weights_acc.to(device)
        class_weights_gyro = class_weights_gyro.to(device)

        inputs_eog, targets_eog = inputs_eog.to(device), targets_eog.to(device)
        inputs_acc, targets_acc = inputs_acc.to(device), targets_acc.to(device)
        inputs_gyro, targets_gyro = inputs_gyro.to(device), targets_gyro.to(device)

        # Feeding the inputs to the model
        predictions_eog, predictions_acc, predictions_gyro = model(x_eog=inputs_eog,
                                                                   x_acc=inputs_acc,
                                                                   x_gyro=inputs_gyro)

        # Calculating the loss
        loss_fn_eog = NLLLoss(weight=class_weights_eog, reduction='mean')
        loss_fn_acc = NLLLoss(weight=class_weights_acc, reduction='mean')
        loss_fn_gyro = NLLLoss(weight=class_weights_gyro, reduction='mean')

        loss_eog = loss_fn_eog(predictions_eog, targets_eog)
        loss_acc = loss_fn_acc(predictions_acc, targets_acc)
        loss_gyro = loss_fn_gyro(predictions_gyro, targets_gyro)

        # Combining the losses
        total_loss = loss_eog + loss_acc + loss_gyro

        total_loss.backward()

        optimizer.step()

        train_loss_eog += loss_eog.item() * inputs_eog.size(0)
        train_loss_acc += loss_acc.item() * inputs_acc.size(0)
        train_loss_gyro += loss_gyro.item() * inputs_gyro.size(0)

        num_train_correct_eog += (torch.max(predictions_eog, 1)[1] == targets_eog).sum().item()
        num_train_examples_eog += inputs_eog.shape[0]

        num_train_correct_acc += (torch.max(predictions_acc, 1)[1] == targets_acc).sum().item()
        num_train_examples_acc += inputs_acc.shape[0]

        num_train_correct_gyro += (torch.max(predictions_gyro, 1)[1] == targets_gyro).sum().item()
        num_train_examples_gyro += inputs_gyro.shape[0]

    train_acc_eog = num_train_correct_eog / num_train_examples_eog
    train_loss_eog = train_loss_eog / len(train_dataloader.dataset)

    train_acc_acc = num_train_correct_acc / num_train_examples_acc
    train_loss_acc = train_loss_acc / len(train_dataloader.dataset)

    train_acc_gyro = num_train_correct_gyro / num_train_examples_gyro
    train_loss_gyro = train_loss_gyro / len(train_dataloader.dataset)
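calculate_class_weight itself is not part of the post. As a hypothetical sketch of per-batch inverse-frequency weights (the function name inverse_frequency_weights, the total / (nb_classes * count) normalization and the zero-count handling are all assumptions), note that some classes can be missing from a batch, as mentioned further down in the thread. A plain division by the count would give inf for those classes, so this sketch sets their weight to 0 explicitly:

```python
import torch

def inverse_frequency_weights(targets: torch.Tensor, nb_classes: int) -> torch.Tensor:
    """Hypothetical per-batch class weights: total / (nb_classes * count).

    Classes that do not occur in the batch get weight 0 instead of inf;
    no target in the batch has that index, so they add no loss terms.
    """
    counts = torch.bincount(targets, minlength=nb_classes).float()
    weights = torch.zeros(nb_classes)
    present = counts > 0
    weights[present] = targets.numel() / (nb_classes * counts[present])
    return weights

targets = torch.tensor([0, 0, 0, 1, 2, 2])  # class 3 is missing from this batch
print(inverse_frequency_weights(targets, nb_classes=4).tolist())  # [0.5, 1.5, 0.75, 0.0]
```

Since the weights are recomputed per batch, the same class can get quite different weights from one batch to the next; computing them once over the whole training set is a common alternative.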

Could you explain the data shapes a bit more, in particular the shapes of the model outputs?

return x_eog[:, -1], x_acc[:, -1], x_gyro[:, -1] 

I assume you are dealing with temporal data and want to use the last time step? If so, did you check the class imbalance for these samples and make sure they don't all belong to a single class?
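One way to do this check is to count the target labels per batch with torch.bincount. A minimal sketch, assuming integer class labels; the random tensor is only a stand-in for one real batch of EOG targets:

```python
import torch

def label_histogram(targets: torch.Tensor, nb_classes: int) -> torch.Tensor:
    """Number of samples per class index in one batch of targets."""
    return torch.bincount(targets.flatten(), minlength=nb_classes)

torch.manual_seed(0)
targets_eog = torch.randint(0, 8, (64,))  # stand-in for one batch of EOG labels

counts = label_histogram(targets_eog, nb_classes=8)
print(counts)
print("classes present:", int((counts > 0).sum()))
print("single class only:", int((counts > 0).sum()) == 1)
```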


Hi ptrblck, thank you for your comment and time!

Indeed, I am working with temporal data. Regarding the data shapes:

The dataset object returns the three different sensor types (EOG, ACC, GYRO), or, as I call them in the comments of the model architecture, the first, second and third sensor. During training, I split each batch into three separate inputs (one per sensor type) and their corresponding labels.
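With a Dataset that returns one (signal, label) pair per sensor, PyTorch's default collate function keeps that nesting, which is why batch[0][0] holds the stacked EOG inputs and batch[0][1] their labels. A minimal sketch with a hypothetical toy dataset (random signals, 10 time steps instead of 3108, fixed dummy labels):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToySensorDataset(Dataset):
    """Hypothetical stand-in: returns (EOG, ACC, GYRO) as (signal, label) pairs."""

    def __len__(self):
        return 100

    def __getitem__(self, index):
        eog = (torch.randn(10, 2), 0)   # [seq_len, channels], dummy label
        acc = (torch.randn(10, 3), 1)
        gyro = (torch.randn(10, 3), 2)
        return eog, acc, gyro

loader = DataLoader(ToySensorDataset(), batch_size=64)
batch = next(iter(loader))

inputs_eog, targets_eog = batch[0][0], batch[0][1]
inputs_acc, targets_acc = batch[1][0], batch[1][1]
print(inputs_eog.shape, targets_eog.shape)  # torch.Size([64, 10, 2]) torch.Size([64])
print(inputs_acc.shape)                     # torch.Size([64, 10, 3])
```

The integer labels are collated into a 1-D int64 tensor per sensor, i.e. one label per sequence, not one per time step.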

The shape of the EOG input (first sensor) is torch.Size([64, 3108, 2]), and the other two have the shape torch.Size([64, 3108, 3]) each. Inside the model, I get the following shapes:

def forward(self, x_eog, x_acc, x_gyro):
    """
    x = All sensors
    x_eog = EOG
    x_acc = ACC
    x_gyro = GYRO
    """
    x_eog = torch.reshape(x_eog,
                          (x_eog.shape[0],
                           x_eog.shape[2],
                           x_eog.shape[1]
                           ))
    # Shape x_eog (EOG): [64, 2, 3108]

    x_acc = torch.reshape(x_acc,
                          (x_acc.shape[0],
                           x_acc.shape[2],
                           x_acc.shape[1]
                           ))
    # Shape x_acc (ACC): [64, 3, 3108]

    x_gyro = torch.reshape(x_gyro,
                           (x_gyro.shape[0],
                            x_gyro.shape[2],
                            x_gyro.shape[1]
                            ))
    # Shape x_gyro (GYRO): [64, 3, 3108]

    # Feeding each sensor to its corresponding CNN block
    x_eog = self.conv_block_eog(x_eog)
    x_acc = self.conv_block_acc(x_acc)
    x_gyro = self.conv_block_gyro(x_gyro)
    # Shape x_eog (EOG): [64, 96, 3069]
    # Shape x_acc (ACC): [64, 96, 3069]
    # Shape x_gyro (GYRO): [64, 96, 3069]

    # Concatenating all sensor types along the batch dimension
    x = torch.cat((x_eog,
                   x_acc,
                   x_gyro), dim=0)
    # Shape of x: torch.Size([192, 96, 3069])

    # Feeding the concatenated sensors to CNN block 2
    x = self.conv_block2(x)
    # Shape of x: torch.Size([192, 128, 3068])

    x = self.global_max_pool(x)
    # Shape of x: torch.Size([192, 128, 1534])

    x = self.dropout(x)
    # Shape of x: torch.Size([192, 128, 1534])

    # Splitting the sensors again along the batch dimension
    x_eog, x_acc, x_gyro = torch.split(x, (x.shape[0] // 3), dim=0)
    # Shape x_eog (EOG): [64, 128, 1534]
    # Shape x_acc (ACC): [64, 128, 1534]
    # Shape x_gyro (GYRO): [64, 128, 1534]

    # Classifier first sensor type
    x_acc = self.classifer_acc_gyro(x_acc)

    # Classifier second sensor type
    x_eog = self.classifer_eog(x_eog)

    # Classifier third sensor type
    x_gyro = self.classifer_acc_gyro(x_gyro)
    # Shape x_eog (EOG): [64, 128, 8]
    # Shape x_acc (ACC): [64, 128, 9]
    # Shape x_gyro (GYRO): [64, 128, 9]

    # Apply LogSoftmax
    x_eog = self.softmax_layer(x_eog)
    x_acc = self.softmax_layer(x_acc)
    x_gyro = self.softmax_layer(x_gyro)
    # Shapes are the same as after the classifiers

    # Returned shapes:
    # x_eog (EOG): [64, 8]
    # x_acc (ACC): [64, 9]
    # x_gyro (GYRO): [64, 9]
    return x_eog[:, -1], x_acc[:, -1], x_gyro[:, -1]
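The sequence lengths in the listing follow from the output-length formula documented for nn.Conv1d and nn.MaxPool1d (with dilation 1): L_out = floor((L_in + 2 * padding - kernel_size) / stride) + 1. A short check that reproduces them:

```python
def out_len(l_in: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Output length of Conv1d / MaxPool1d with dilation 1."""
    return (l_in + 2 * padding - kernel_size) // stride + 1

length = 3108
for k in (24, 16, 8):                         # sensor-specific CNN blocks, padding=1
    length = out_len(length, k, padding=1)
print(length)                                 # 3069

length = out_len(length, 4, padding=1)        # conv_block2
print(length)                                 # 3068

length = out_len(length, 2, stride=2)         # MaxPool1d(kernel_size=2, stride=2)
print(length)                                 # 1534
```

Since the pooling only halves the length, the classifiers still receive 1534 time steps, and nn.LazyLinear maps that last dimension to 512, which is why the classifier outputs are [64, 128, nb_classes] rather than [64, nb_classes].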

return x_eog[:, -1], x_acc[:, -1], x_gyro[:, -1]

I assume you are dealing with temporal data and want to use the last time step? 

I return my outputs like that because otherwise I get the following error when calculating the loss for each sensor (example for the first sensor):

Expected target size [64, 8], got [64]

I couldn’t figure out another way, but I want to use all time steps, not only the last one.

If so, did you check the class imbalance for these samples and made sure they are not a single class only?

I just checked them: no sample contains only a single class (sometimes some of the classes are missing, but there is never just one class).

Thank you so much again for your time!


Just in case, here is my dataset class for more information. About the idea behind this dataset: the dataset is unlabelled. The idea is to split the sensor types of the dataset into three separate variables, run a data augmentation on each sensor type, and label each one according to the transformation performed. The model's task is then an 8-class classification for the first sensor (EOG) and a 9-class classification for each of the other two.

import numpy as np
import torch
from torch.utils.data import Dataset

# signal_Transformationen and NormalizeData are my own helpers (not shown here)


class Dataset(Dataset):
    def __init__(self, data, max_sequence_length, feature_scaling):

        self.data = data  # <- later, put the np.load here

        self.max_sequence_length = max_sequence_length

        self.available_transformations = {
            "None": [],
            "Noise_Addition": [x * 0.1 for x in range(0, 5)][1:],
            "Scale": list(range(5, 11)),
            "Vertical_Flip": [],
            "Horizontal_Flip": [],
            "Permutation": [],
            "Time_Warp": [2],
            "Channel_Shuffle": ['EOG', 'ACC', 'GYRO'],
            "Rotation": []
        }

        self.feature_scaling = feature_scaling

    def apply_augmentation(self, signal, augmentation):

        if augmentation == 'Noise_Addition':
            amount_of_noise = np.random.choice(list(self.available_transformations.values())[1])
            signal_aug = signal_Transformationen.add_random_noise(signal,
                                                                  noise_amount=amount_of_noise)

        elif augmentation == 'Scale':
            scale_factor = np.random.choice(list(self.available_transformations.values())[2])
            signal_aug = signal_Transformationen.scaling(signal,
                                                         factor=scale_factor)

        elif augmentation == 'Vertical_Flip':
            signal_aug = signal_Transformationen.vertical_flip(signal)

        elif augmentation == 'Horizontal_Flip':
            signal_aug = signal_Transformationen.horizontal_flip(signal)

        elif augmentation == 'Permutation':
            signal_aug = signal_Transformationen.permutation(signal)

        elif augmentation == 'Time_Warp':
            signal_aug = signal_Transformationen.time_warp(signal,
                                                           factor=list(self.available_transformations.values())[6][0])

        elif augmentation == 'Channel_Shuffle':
            signal_aug = signal_Transformationen.channel_shuffle(signal)

        elif augmentation == 'Rotation':
            signal_aug = signal_Transformationen.rotation(signal)

        else:
            signal_aug = signal

        return signal_aug

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        x_org = self.data[index]

        # Only one transformation is applied per segment.
        # Every transformation has the same probability, i.e. for the EOG sensors it is 12.5% each
        # (7 augmentations + no augmentation).
        # For ACC and GYRO it is about 11.1% each (8 augmentations + no augmentation).
        # Probability per sensor type: EOG: 12.5%, ACC & GYRO: about 11.1% each, no transformation at all:

        x_aug_eog = x_org[:, 0:2]
        x_aug_acc = x_org[:, 2:5]
        x_aug_gyro = x_org[:, 5:8]

        # Selection of the data augmentation
        # For ACC and GYRO:
        name_of_augmentation_acc_gyro = np.random.choice(list(self.available_transformations.keys()))

        # For EOG:
        name_of_augmentation_eog = np.random.choice(list(self.available_transformations.keys())[0:8])

        # Applying the augmentation and storing the label
        x_aug_eog = self.apply_augmentation(x_aug_eog, name_of_augmentation_eog)
        label_eog = list(self.available_transformations)[0:8].index(name_of_augmentation_eog)

        x_aug_acc = self.apply_augmentation(x_aug_acc, name_of_augmentation_acc_gyro)
        label_acc = list(self.available_transformations).index(name_of_augmentation_acc_gyro)

        x_aug_gyro = self.apply_augmentation(x_aug_gyro, name_of_augmentation_acc_gyro)
        label_gyro = list(self.available_transformations).index(name_of_augmentation_acc_gyro)

        # Normalizing each signal (the conversion to tensors happens in the return statement)
        x_aug_eog = NormalizeData(x_aug_eog, scaling_type=self.feature_scaling)

        x_aug_acc = NormalizeData(x_aug_acc, scaling_type=self.feature_scaling)

        x_aug_gyro = NormalizeData(x_aug_gyro, scaling_type=self.feature_scaling)

        # Padding the data to account for the different window sizes of each sensor type
        # (otherwise merging them later causes problems, because these are 3 different batches)
        if x_aug_eog.shape[0] < self.max_sequence_length:
            x_aug_eog = np.pad(x_aug_eog, [(self.max_sequence_length - len(x_aug_eog), 0), (0, 0)])

            # torch.from_numpy(x_aug_eog.copy()).float()

        if x_aug_acc.shape[0] < self.max_sequence_length:
            x_aug_acc = np.pad(x_aug_acc, [(self.max_sequence_length - len(x_aug_acc), 0), (0, 0)])

        if x_aug_gyro.shape[0] < self.max_sequence_length:
            x_aug_gyro = np.pad(x_aug_gyro, [(self.max_sequence_length - len(x_aug_gyro), 0), (0, 0)])

        return ((torch.from_numpy(x_aug_eog.copy()).float(), label_eog),
                (torch.from_numpy(x_aug_acc.copy()).float(), label_acc),
                (torch.from_numpy(x_aug_gyro.copy()).float(), label_gyro))
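The label scheme in the dataset can be checked in isolation. With a uniform choice over the augmentation names, EOG has 8 options (12.5% each) and ACC/GYRO have 9 (1/9, about 11.1% each), and each label is simply the position of the name in the dictionary. A small standard-library sketch; the names are copied in order from available_transformations above, and the sampling and augmentation themselves are left out:

```python
from fractions import Fraction

available_transformations = [
    "None", "Noise_Addition", "Scale", "Vertical_Flip", "Horizontal_Flip",
    "Permutation", "Time_Warp", "Channel_Shuffle", "Rotation",
]

eog_options = available_transformations[0:8]   # EOG excludes "Rotation"
acc_gyro_options = available_transformations   # ACC and GYRO use all 9

p_eog = Fraction(1, len(eog_options))
p_acc_gyro = Fraction(1, len(acc_gyro_options))
print(float(p_eog))        # 0.125
print(float(p_acc_gyro))   # 0.1111111111111111

# Labels are list indices, so "Channel_Shuffle" maps to 7 for every sensor type
print(eog_options.index("Channel_Shuffle"), acc_gyro_options.index("Channel_Shuffle"))  # 7 7
```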

Thanks for the follow-up.

A few clarifications:

  • The reshape operation looks wrong, as it seems you want to permute the dimensions instead:
x_eog = torch.reshape(x_eog,
                      (x_eog.shape[0],
                       x_eog.shape[2],
                       x_eog.shape[1]
                       ))

(same for the other sensors)
If you want to change the order of the dimensions, use x = x.permute(0, 2, 1) instead, unless you really want to interleave the tensor.
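To see the difference: reshape keeps the elements in memory order and only reinterprets the shape, so values from different channels end up mixed in the same row, while permute actually swaps the axes. A small example with a [1, 4, 2] tensor laid out as [batch, seq_len, channels]:

```python
import torch

x = torch.arange(8).reshape(1, 4, 2)  # [batch=1, seq_len=4, channels=2]
# channel 0 holds 0, 2, 4, 6 and channel 1 holds 1, 3, 5, 7

reshaped = torch.reshape(x, (x.shape[0], x.shape[2], x.shape[1]))
permuted = x.permute(0, 2, 1)

print(reshaped[0].tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]] -> channels interleaved
print(permuted[0].tolist())  # [[0, 2, 4, 6], [1, 3, 5, 7]] -> one row per channel
```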

  • The error is raised because the target would also need to contain a label for each time step.
I.e. for nn.CrossEntropyLoss or nn.NLLLoss, the model output should have the shape [batch_size, nb_classes, seq_len] and the target the shape [batch_size, seq_len] for a multi-class classification on temporal data. You would thus have to permute the model outputs again via out = out.permute(0, 2, 1).contiguous() and make sure your target contains all the needed labels.
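A minimal sketch of this per-time-step setup, with hypothetical sizes (batch of 4, 10 time steps, 9 classes) and random data: the output arrives as [batch_size, seq_len, nb_classes] (as an nn.Linear applied to [batch_size, seq_len, features] would produce), is permuted to [batch_size, nb_classes, seq_len], and the target holds one label per time step. The last part shows that a target with only one label per sequence triggers the same kind of "Expected target size" error as above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, nb_classes = 4, 10, 9

out = torch.randn(batch_size, seq_len, nb_classes)
out = out.permute(0, 2, 1).contiguous()          # [batch_size, nb_classes, seq_len]
log_probs = nn.LogSoftmax(dim=1)(out)            # normalize over the class dimension

target = torch.randint(0, nb_classes, (batch_size, seq_len))  # one label per time step

loss = nn.NLLLoss()(log_probs, target)
print(loss.shape, loss.item())

# A target with a single label per sequence ([batch_size]) does not match
try:
    nn.NLLLoss()(log_probs, torch.randint(0, nb_classes, (batch_size,)))
    mismatch_raised = False
except (RuntimeError, ValueError) as err:
    mismatch_raised = True
    print("shape error:", err)
```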