Neural Network with Evolution Strategies optimizer keeps outputting the same accuracy on MNIST

My task is to create an ANN with an Evolution Strategies (ES) algorithm as the optimizer (no derivatives/backpropagation). The dataset I am using is MNIST. For now, I am just trying to implement this with a simple ANN built from Linear layers.

I found a Colab notebook that does this exact thing, but on sklearn's “make_moons” dataset. I tried to adapt what was in the notebook, and the code runs without errors; yet it keeps outputting the same accuracy. Usually the first few outputs differ, then it “converges” to 0.0987 on the training set and 0.098 on the test set, which is roughly chance level for 10 classes. It also takes extremely long to train: with batch_size = 64 there are about 938 batches, each running 500 ES iterations over a population of 50, plus a full pass over the training set for the accuracy print after every batch. Maybe there are redundant iterations?

Colab Notebook, if you want to check it out: Google Colab

I tried some StackOverflow recommendations, such as adjusting the hyperparameters (learning rate, hidden units) and using Leaky ReLU in case of a “dying ReLU”; none of them worked. This leads me to believe that the problem is in the ES optimizer.
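
For reference, the Leaky ReLU variant I tried simply swapped out the activations, roughly like this (it behaved the same way):

# Same architecture as in the full listing below, but with LeakyReLU(0.1) in place of the ReLUs
model = nn.Sequential(
      nn.Linear(input_size, 40),
      nn.LeakyReLU(0.1),
      nn.Linear(40, 20),
      nn.LeakyReLU(0.1),
      nn.Linear(20, num_classes),
      nn.LeakyReLU(0.1),
)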

I am new to PyTorch, so if there are any glaring malpractices, please say so!

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
    
# Set device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NN & DataLoader hyperparameters
input_size = 784
num_classes = 10
learning_rate = 0.01
batch_size = 64
num_epochs = 1 

# Load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 

test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=False) 
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True) 

# Fully connected NN: 784 -> 40 -> 20 -> 10
model = nn.Sequential(
      nn.Linear(input_size, 40),
      nn.ReLU(0.1),
      nn.Linear(40, 20),
      nn.ReLU(0.1),
      nn.Linear(20, num_classes),
      nn.ReLU(0.1),
)
model = model.float().to(device)

# Base cross-entropy loss; the custom ES loss below is its reciprocal
loss_func = nn.CrossEntropyLoss()

def loss(y_pred, y_true):
  return 1/loss_func(y_pred, y_true) # We are maximizing the loss in ES, so take the reciprocal
  # Now, increasing loss means the model is learning
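  # E.g. a cross-entropy of ~2.30 (chance level on 10 classes) maps to ~0.43 here; a smaller cross-entropy maps to a larger value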

# Fitness function
def fitness_func(solution, scores, targets):
  # Solution is a vector of parameters like mother_parameters
  nn.utils.vector_to_parameters(solution, model.parameters())
  return loss(scores, targets)

# In ES, our population is a slightly altered version of the mother parameters, so we implement a jitter function
def jitter(mother_params, state_dict):
  params_try = mother_params + SIGMA*state_dict.to(device)
  return params_try

# Now, we calculate the fitness of the entire population
def calculate_population_fitness(pop, mother_vector, scores, targets):
  fitness = torch.zeros(pop.shape[0])
  for i, params in enumerate(pop):
    p_try = jitter(mother_vector, params)
    fitness[i] = fitness_func(p_try, scores, targets)
  return fitness

# Calculating number of parameters
n_params = nn.utils.parameters_to_vector(model.parameters()).shape[0]
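# For this architecture: (784*40 + 40) + (40*20 + 20) + (20*10 + 10) = 32,430 parameters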

# Now, implementing the training algorithm
mother_parameters = model.parameters()
mother_vector = nn.utils.parameters_to_vector(mother_parameters)

# ES hyperparameters
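# SIGMA: std-dev of the parameter perturbations, LR: step size of the mother-vector update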
SIGMA = 0.01
LR = 0.01
POPULATION_SIZE=50
ITERATIONS = 500 

# Train network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):

        data = data.to(device=device)
        targets = targets.to(device=device)

        # Correcting shape
        data = data.reshape(data.shape[0], -1)

        print(f"{batch_idx} out of {len(train_loader)}")
        
        # ES optimizer
        with torch.no_grad(): # No need for gradients
            for iteration in tqdm(range(ITERATIONS)):
                scores = model(data)
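                # Sample POPULATION_SIZE Gaussian noise vectors, one perturbation direction per population member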
                pop = torch.from_numpy(np.random.randn(POPULATION_SIZE, n_params)).float().to(device)
                fitness = calculate_population_fitness(pop, mother_vector, scores, targets)
                # Normalize the fitness
                normalized_fitness = ((fitness - torch.mean(fitness)) / torch.std(fitness)).to(device)
                # Update mother vector with the fitness values
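                # i.e. theta <- theta + LR/(POPULATION_SIZE*SIGMA) * sum_i normalized_fitness_i * noise_i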
                mother_vector = mother_vector.to(device) + (LR / (POPULATION_SIZE * SIGMA)) * torch.matmul(pop.t(), normalized_fitness)

        # Update the model parameters
        nn.utils.vector_to_parameters(mother_vector, model.parameters())

        # Computing accuracy
        num_correct = 0
        num_samples = 0

        for x, y in train_loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x = x.reshape(x.shape[0], -1)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        
        print(num_correct, num_samples)
        print(f"accuracy {float(num_correct)/float(num_samples)*100:.2f}")
        print("------------------------------------------")