Metrics freeze after the first epoch

Hello everyone. My training metrics stop improving after the first epoch. For this task I implemented the following 1D-CNN:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, num_classes=8705, input_length=224):
        super(Model, self).__init__()
        self.act_fn = nn.LeakyReLU(negative_slope=0.05)

        self.conv1_a = nn.Conv1d(1, 32, kernel_size=1)
        self.bn1_a = nn.BatchNorm1d(32)
        self.conv1_b = nn.Conv1d(1, 32, kernel_size=3, padding=1)
        self.bn1_b = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2_a = nn.Conv1d(64, 32, kernel_size=1)
        self.bn2_a = nn.BatchNorm1d(32)
        self.conv2_b = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        self.bn2_b = nn.BatchNorm1d(32)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3_a = nn.Conv1d(64, 32, kernel_size=1)
        self.bn3_a = nn.BatchNorm1d(32)
        self.conv3_b = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        self.bn3_b = nn.BatchNorm1d(32)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv4_a = nn.Conv1d(64, 32, kernel_size=1)
        self.bn4_a = nn.BatchNorm1d(32)
        self.conv4_b = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        self.bn4_b = nn.BatchNorm1d(32)
        self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv5_a = nn.Conv1d(64, 32, kernel_size=1)
        self.bn5_a = nn.BatchNorm1d(32)
        self.conv5_b = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        self.bn5_b = nn.BatchNorm1d(32)
        self.pool5 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.final_conv = nn.Conv1d(64, num_classes, kernel_size=1)
        self.final_bn = nn.BatchNorm1d(num_classes)
        self.gap = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        b1_a = self.act_fn(self.bn1_a(self.conv1_a(x)))
        b1_b = self.act_fn(self.bn1_b(self.conv1_b(x)))
        x = torch.cat([b1_a, b1_b], dim=1)  # concatenate along channels (32 + 32 = 64)
        x = self.pool1(x)

        b2_a = self.act_fn(self.bn2_a(self.conv2_a(x)))
        b2_b = self.act_fn(self.bn2_b(self.conv2_b(x)))
        x = torch.cat([b2_a, b2_b], dim=1)
        x = self.pool2(x)

        b3_a = self.act_fn(self.bn3_a(self.conv3_a(x)))
        b3_b = self.act_fn(self.bn3_b(self.conv3_b(x)))
        x = torch.cat([b3_a, b3_b], dim=1)
        x = self.pool3(x)

        b4_a = self.act_fn(self.bn4_a(self.conv4_a(x)))
        b4_b = self.act_fn(self.bn4_b(self.conv4_b(x)))
        x = torch.cat([b4_a, b4_b], dim=1)
        x = self.pool4(x)

        b5_a = self.act_fn(self.bn5_a(self.conv5_a(x)))
        b5_b = self.act_fn(self.bn5_b(self.conv5_b(x)))
        x = torch.cat([b5_a, b5_b], dim=1)
        x = self.pool5(x)

        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.act_fn(x)
        x = self.gap(x)
        x = x.squeeze(-1)
        return x
```

The model takes a 1×224 vector as input and has to distinguish 8705 classes. The training set contains 8705 × 3600 samples (3600 per class) and the validation set contains 8705 samples. The 3600 samples per class in the training set are further split 80/20, keeping all 8705 classes in both parts. For training I use SGD, CrossEntropyLoss, and a StepLR scheduler with lr = 0.01, momentum = 0.9, weight_decay = 0.00005, lr_step_size = 10000, lr_gamma = 0.01.
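For reference, the settings listed above correspond roughly to the PyTorch setup below. This is a sketch: the post does not show where `scheduler.step()` is called, and the helper names `make_optimizer_and_scheduler` and `lr_after` are hypothetical. StepLR counts calls to `scheduler.step()`, so `step_size=10000` means 10,000 batches if it is called once per batch. At roughly 25 million training samples (80% of 8705 × 3600) and a hypothetical batch size of 256, one epoch is about 98,000 batches, so a per-batch scheduler would cut the LR to 1e-4 after 10,000 batches and to 1e-6 after 20,000, both well inside the first epoch. The `LR=` field in the progress bar would show this.

```python
import torch
import torch.nn as nn


def make_optimizer_and_scheduler(params):
    # Settings as listed in the post: SGD + StepLR
    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10_000, gamma=0.01)
    return optimizer, scheduler


def lr_after(num_scheduler_steps):
    """Learning rate after calling scheduler.step() the given number of times.

    Uses a single dummy parameter with a zero gradient, so no real training happens.
    """
    param = nn.Parameter(torch.zeros(1))
    param.grad = torch.zeros_like(param)
    optimizer, scheduler = make_optimizer_and_scheduler([param])
    for _ in range(num_scheduler_steps):
        optimizer.step()   # step the optimizer first, as PyTorch expects
        scheduler.step()
    return optimizer.param_groups[0]["lr"]


for steps in (9_999, 10_000, 20_000):
    print(steps, lr_after(steps))
```

If the scheduler is meant to act per epoch, calling `scheduler.step()` once after each `train_epoch` (rather than inside the batch loop) makes `step_size` count epochs instead.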

```python
import torch
from tqdm import tqdm


# Method of the trainer class (self.model, self.optimizer, self.criterion,
# self.train_loader, self.device and self.history are set up elsewhere).
def train_epoch(self, epoch, total_epochs):
    self.model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(
        self.train_loader,
        desc=f"Ep {epoch}/{total_epochs}",
        leave=False,
        ncols=80,
        bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{postfix}]'
    )

    for batch_idx, (data, targets) in enumerate(pbar):
        data, targets = data.to(self.device), targets.to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(data)
        loss = self.criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        batch_loss = loss.item()
        predicted = outputs.argmax(dim=1)
        batch_correct = (predicted == targets).sum().item()
        batch_total = targets.size(0)

        running_loss += batch_loss
        correct += batch_correct
        total += batch_total

        avg_loss = running_loss / (batch_idx + 1)
        avg_acc = 100 * correct / total

        pbar.set_postfix_str(
            f"Loss={avg_loss:.4f}, Acc={avg_acc:.2f}%, LR={self.optimizer.param_groups[0]['lr']:.1e}"
        )

        self.history['batch_loss'].append(avg_loss)
        self.history['batch_acc'].append(avg_acc)

    return running_loss / len(self.train_loader), 100 * correct / total
```
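One quick check on this loop is whether the gradient clipping is active on most batches. `torch.nn.utils.clip_grad_norm_` returns the total gradient norm measured before clipping; if that value is far above `max_norm=1.0` on nearly every batch, every SGD update is being rescaled, and the effective step is smaller than lr = 0.01 alone suggests. The helper below is a hypothetical sketch of the same step logic that also returns that norm for logging.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, criterion, data, targets, max_norm=1.0):
    """One optimization step; returns (outputs, loss value, pre-clipping grad norm)."""
    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the norm *before* clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return outputs, loss.item(), grad_norm.item()
```

Inside `train_epoch` it could replace the step lines as `outputs, batch_loss, grad_norm = train_step(self.model, self.optimizer, self.criterion, data, targets)`, with `GradNorm={grad_norm:.2f}` added to the progress-bar postfix.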

Can you tell me what the problem might be? I've been struggling with this for a long time. I tried switching from SGD to AdamW and the metrics improved significantly, but I find that suspicious.

Thanks

Hi there! It is common for plain SGD to stall or "freeze" early in training where AdamW keeps making progress. SGD uses one global learning rate, so each parameter's step is proportional to its gradient: if the gradients become very small, the updates become tiny and the metrics plateau; if they become very large, training becomes unstable. AdamW divides each parameter's update by a running estimate of that parameter's gradient magnitude, which effectively gives every parameter its own step size and makes it much more robust to a poorly tuned learning rate or initialization.
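To make the step-size difference concrete, here is a toy sketch (not from the original post) that applies a single optimizer step to one zero-initialized parameter, once with a tiny gradient and once with a large one. The AdamW learning rate of 1e-3 is an assumed common default, and its weight decay is switched off so that only the gradient scaling is compared.

```python
import torch


def first_step_size(opt_cls, grad_value, **opt_kwargs):
    """Size of the first update an optimizer applies to a single
    zero-initialized parameter whose gradient is `grad_value`."""
    p = torch.nn.Parameter(torch.zeros(1))
    opt = opt_cls([p], **opt_kwargs)
    p.grad = torch.full_like(p, grad_value)
    opt.step()
    return p.detach().abs().item()


for g in (1e-6, 1.0):
    sgd = first_step_size(torch.optim.SGD, g, lr=0.01, momentum=0.9)
    adamw = first_step_size(torch.optim.AdamW, g, lr=1e-3, weight_decay=0.0)
    print(f"grad={g:.0e}: SGD moved {sgd:.1e}, AdamW moved {adamw:.1e}")
```

SGD's first step is lr × gradient, so a 1e-6 gradient moves the weight by only about 1e-8, while AdamW's first step is close to lr for both gradients because it divides by the gradient's own magnitude. Later AdamW steps are not exactly lr in size, but this normalization is what lets it keep moving when gradients are small.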

Using AdamW is completely normal and often the recommended default for modern architectures! However, if you really want or need to get SGD working, here are a few things you can tweak:

  1. Learning rate: 1e-2 may be too high or too low for your batch size. Try a quick sweep (e.g. 1e-1, 1e-3, 1e-4) over a few hundred batches and compare the resulting losses.
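  2. Initialization: Kaiming (He) initialization matched to the LeakyReLU slope keeps activation and gradient magnitudes roughly stable from layer to layer; with a poor initialization they can vanish or explode within the first epoch, which hurts SGD in particular. Note that PyTorch's default Conv1d initialization is already a Kaiming-uniform variant, and your model has BatchNorm after every convolution, so the effect here may be modest.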

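```python
import torch.nn as nn


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        # Kaiming init matched to the model's LeakyReLU(negative_slope=0.05)
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu', a=0.05)
    elif isinstance(m, nn.BatchNorm1d):
        # These are already PyTorch's BatchNorm defaults; set explicitly for clarity
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


model.apply(init_weights)  # `model` is an instance of the Model class from the question
```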

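  3. Check your validation loop: call self.model.eval() before validating, so that BatchNorm normalizes with its running statistics instead of per-batch statistics (and does not update them), and wrap the loop in with torch.no_grad(): so that no autograd graph is built, which saves memory and time.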
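A minimal validation function along these lines might look like the sketch below. The names `validate` and `loader` are hypothetical, and it assumes the criterion uses CrossEntropyLoss's default mean reduction.

```python
import torch


@torch.no_grad()  # no autograd graph is built anywhere inside this function
def validate(model, loader, criterion, device="cpu"):
    model.eval()  # BatchNorm uses (and does not update) its running statistics
    total_loss, correct, total = 0.0, 0, 0
    for data, targets in loader:
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        # criterion returns the batch mean, so weight it by batch size
        total_loss += criterion(outputs, targets).item() * targets.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    return total_loss / total, 100.0 * correct / total
```

It could be called after each epoch as `val_loss, val_acc = validate(self.model, val_loader, self.criterion, self.device)`. Switching back to training mode is not needed, since `train_epoch` already calls `self.model.train()` at its start.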

Overall, AdamW outperforming SGD here is not in itself suspicious: coping with poorly scaled gradients is what adaptive optimizers were designed for. As long as the improvement also shows up on your validation set, I'd recommend sticking with AdamW.
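If you want to check that the AdamW result is not a fluke, one option is a short side-by-side run of several optimizer settings from the same initial weights. This is a hypothetical sketch: `quick_sweep` and `make_model` are illustrative names, the AdamW learning rate of 1e-3 is an assumed common default, and `num_batches` is kept small, so it only shows which settings start making progress, not final accuracy.

```python
import itertools

import torch
import torch.nn as nn


def quick_sweep(make_model, loader, configs, num_batches=200, device="cpu"):
    """Train a fresh model briefly for each optimizer config and return the
    mean training loss over the last 20% of the batches seen."""
    criterion = nn.CrossEntropyLoss()
    results = {}
    for name, make_opt in configs.items():
        torch.manual_seed(0)  # identical initial weights for every run
        model = make_model().to(device)
        model.train()
        optimizer = make_opt(model.parameters())
        losses = []
        for data, targets in itertools.islice(loader, num_batches):
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), targets)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # as in the post's loop
            optimizer.step()
            losses.append(loss.item())
        if not losses:
            raise ValueError("loader produced no batches")
        tail = losses[-max(1, len(losses) // 5):]
        results[name] = sum(tail) / len(tail)
    return results


# SGD at several learning rates (momentum and weight decay from the post) vs. AdamW
configs = {
    f"SGD lr={lr:g}": (lambda params, lr=lr: torch.optim.SGD(
        params, lr=lr, momentum=0.9, weight_decay=5e-5))
    for lr in (1e-1, 1e-2, 1e-3, 1e-4)
}
configs["AdamW lr=1e-3"] = lambda params: torch.optim.AdamW(params, lr=1e-3, weight_decay=5e-5)
```

With the question's code this would be run as `print(quick_sweep(Model, train_loader, configs))`. The `lr=lr` default argument is needed so that each lambda keeps its own learning rate instead of all of them using the last value of the loop.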
