Large difference in results between runs

I did a PyTorch implementation of a paper (the code is below), and I have an issue that I'm not able to solve.
In some runs, the network reproduces the results reported in the paper (0.97 accuracy), and in other runs the results are much worse (0.82).
I tried to figure out where exactly the problem is, and the only thing I found is that there is a range of seeds that gives good results (0.97) and other seeds that don't. I couldn't figure out what exactly makes the difference. The only explanations I can think of are the weight initialization or a bug in my code that I can't see.
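To compare a "good" and a "bad" seed directly, it helps to pin every source of randomness before building the model. A minimal sketch, assuming the pipeline (weight init, DataLoader shuffling, any NumPy-based affinity computation) draws only from Python's `random`, NumPy and PyTorch RNGs; `set_seed` is a hypothetical helper name:

```python
import random

import numpy as np
import torch
import torch.nn as nn


def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (torch.manual_seed covers all devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Same seed -> identical initial weights, so a "bad" run can be replayed exactly.
set_seed(0)
a = nn.Linear(784, 1024)
set_seed(0)
b = nn.Linear(784, 1024)
print(torch.equal(a.weight, b.weight))  # True
```

Seeding before the DataLoader is created also fixes the shuffling order, since a DataLoader with `shuffle=True` and no explicit `generator` draws its seed from the global PyTorch RNG. Some CUDA kernels can still be nondeterministic; `torch.use_deterministic_algorithms(True)` makes PyTorch use deterministic implementations or raise an error where none exists.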

The dataset I train on is MNIST.

Here is my code:

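import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# init_weights, get_affinity_matrix and SPECTRAL_HYPERPARAMS are defined
# elsewhere in the project (not shown in this post).

class SpectralNet(nn.Module):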
    def __init__(self, input_dim, architecture):
        super(SpectralNet, self).__init__()
        self.input_dim = input_dim
        self.architecture = architecture
        self.layers = nn.ModuleList()
        self.num_of_layers = self.architecture['num_of_layers']

        current_dim = self.input_dim
        for i in range(1, self.num_of_layers):
            next_dim = self.architecture[f"layer{i}"]
            if i == self.num_of_layers - 1:
                layer = nn.Sequential(nn.Linear(current_dim, next_dim), nn.Tanh())
            else:
                layer = nn.Sequential(nn.Linear(current_dim, next_dim), nn.ReLU())
            
            self.layers.append(layer)
            current_dim = next_dim
        self.apply(init_weights)

    def forward(self, x, orthonorm_step=True):
        for layer in self.layers:
            x = layer(x)
        
        if orthonorm_step:
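            # Cholesky factor of x^T x (= L L^T). Scaling L^{-T} by sqrt(batch size)
            # makes the output columns orthogonal with squared norm equal to the batch size.
            L = torch.linalg.cholesky(torch.mm(torch.t(x), x), upper=False)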
            self.orthonorm_weights = np.sqrt(x.shape[0]) * torch.t(torch.inverse(L))
                
        return torch.mm(x,self.orthonorm_weights)


class SpectralNetLoss(nn.Module):
    def __init__(self):
        super(SpectralNetLoss, self).__init__()

    def forward(self, W, Y, normalized=False):
        if normalized:
            D = W.sum(1)
            Y = Y / D[:, None]

        Dis_y = torch.cdist(Y,Y, p=2) ** 2
        return torch.sum(Dis_y*W) / (W.shape[0])


class SpectralNetOperations():
    def __init__(self, model, dataset, device, siamese_net=None):
        self.model = model
        self.dataset = dataset
        self.device = device
        self.siamese_net = siamese_net
        self.loss_function = SpectralNetLoss()
        self.lr = SPECTRAL_HYPERPARAMS[self.dataset]["initial_lr"]
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')
        self.counter = 0


    def ortho_step(self, x):
        self.model.eval()
        return self.model(x, orthonorm_step=True)

    def train_step(self, x):
        self.model.train()
        self.optimizer.zero_grad()
        
        Y = self.model(x, orthonorm_step=False)
        if self.siamese_net is not None:
            self.siamese_net.eval()
            with torch.no_grad():
                x = self.siamese_net.forward_once(x)

        W = get_affinity_matrix(x, self.dataset).to(self.device)

        loss = self.loss_function(W, Y)
        loss.backward()
        self.optimizer.step()
        return loss.item() 

    def valid_step(self, x, y):
        with torch.no_grad():
            Y = self.model(x, orthonorm_step=False)
            
        if self.siamese_net is not None:
            self.siamese_net.eval()
            with torch.no_grad():
                x = self.siamese_net.forward_once(x)

        W = get_affinity_matrix(x, self.dataset).to(self.device)
        
        # if self.counter % 10 == 0:
            # plot_laplacian_eigenvectors(Y, y)


        loss = self.loss_function(W, Y)
        return loss.item()

    def train(self, train_loader, valid_loader):
        print("Training SpectralNet: ")
        epochs = SPECTRAL_HYPERPARAMS[self.dataset]["num_epochs"]
        orthonorm = True

        for epoch in range(1, epochs + 1):
            epoch_loss = 0.0
            batches_len = 0.0
            for batch_x, _ in train_loader:
                batch_x = batch_x.to(device=self.device)
                batch_x = batch_x.view(batch_x.size(0), -1)

                if orthonorm:
                    self.ortho_step(batch_x)
                else:
                    loss = self.train_step(batch_x) 
                    epoch_loss += loss
                    batches_len += 1
                
                orthonorm = not orthonorm

            epoch_loss = epoch_loss / batches_len
            validation_loss = self.validate(valid_loader)
                
            self.scheduler.step(validation_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            if current_lr <= 1e-8: break
            print('SpectralNet learning rate = %.7f' % current_lr) 
            print('Epoch {} of {}, Train Loss: {:.7f} | Validation Loss: {:.7f}'
                .format(epoch, epochs, epoch_loss, validation_loss))
            
        torch.save(self.model.state_dict(), f"./networks/weights/mnist_spectral.pth")
        print('Finished training SpectralNet')

    def validate(self, valid_loader):
        validate_loss = 0.0

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y in valid_loader:
                batch_x = batch_x.to(device=self.device)
                batch_x = batch_x.view(batch_x.size(0), -1)
                loss = self.valid_step(batch_x, batch_y)
                validate_loss += loss

        self.counter += 1
        validate_loss = validate_loss / len(valid_loader)
        return validate_loss

Brief explanation:
The goal of this network is to approximate the eigenvectors of the Laplacian matrix that is obtained from an affinity matrix W.
During training, the network samples a batch of data points, computes W from the batch, and minimizes the following loss:
$$\mathcal{L}(\theta) = \frac{1}{m^2}\sum_{i,j=1}^{m} W_{ij}\,\lVert y_i - y_j \rVert^2$$
where m is the batch size and y_i are the outputs of the network.
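The link between this loss and the Laplacian can be checked numerically. For a symmetric affinity matrix W, the pairwise sum equals 2·tr(Yᵀ(D − W)Y), where D is the diagonal degree matrix and D − W is the unnormalized graph Laplacian; minimizing tr(YᵀLY) under an orthonormality constraint on Y is solved by the Laplacian's eigenvectors with the smallest eigenvalues. A small sketch on random data (not MNIST), computing the pairwise sum the same way as the posted SpectralNetLoss but without its constant normalization factor:

```python
import torch

torch.manual_seed(0)
m, k = 8, 3
Y = torch.randn(m, k, dtype=torch.float64)  # network outputs y_i as rows
A = torch.rand(m, m, dtype=torch.float64)
W = (A + A.T) / 2                           # symmetric affinity matrix
D = torch.diag(W.sum(dim=1))                # degree matrix
lap = D - W                                 # unnormalized graph Laplacian

# Pairwise form, as in SpectralNetLoss (without the normalization constant)
pairwise = torch.sum(torch.cdist(Y, Y, p=2) ** 2 * W)

# Trace form
trace = 2 * torch.trace(Y.T @ lap @ Y)

print(torch.allclose(pairwise, trace))  # True
```

The identity only needs W to be symmetric; the normalization constant rescales the loss but does not change its minimizers.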
There is also an orthogonalization layer (self.orthonorm_weights) that is used to make the outputs orthogonal.
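The orthogonalization step in forward() can be checked in isolation: if xᵀx = LLᵀ, then W_o = √m·L⁻ᵀ gives (xW_o)ᵀ(xW_o) = m·I. A minimal sketch of that step (the helper name `orthonorm_weights` is hypothetical). Here the weights are computed under `torch.no_grad()`, so they act as constants during the following gradient step. That is an assumption about the intended behavior: the posted ortho_step keeps the autograd graph, so the loss in the next train_step also backpropagates through orthonorm_weights into the previous batch's forward pass.

```python
import math

import torch


def orthonorm_weights(x: torch.Tensor) -> torch.Tensor:
    """Return W_o with (x @ W_o).T @ (x @ W_o) == m * I for x of shape (m, k)."""
    m = x.shape[0]
    with torch.no_grad():  # treat W_o as a constant, not part of the autograd graph
        L = torch.linalg.cholesky(x.T @ x)           # x^T x = L L^T, L lower-triangular
        return math.sqrt(m) * torch.linalg.inv(L).T  # sqrt(m) * L^{-T}


torch.manual_seed(0)
x = torch.randn(256, 10, dtype=torch.float64, requires_grad=True)
w_o = orthonorm_weights(x)
y = x @ w_o
print(w_o.requires_grad)  # False
print(torch.allclose(y.T @ y / x.shape[0], torch.eye(10, dtype=torch.float64)))  # True
```

Calling `.detach()` on the stored weights in the posted code would have the same effect as the `no_grad` block here.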
I would be really grateful for any idea about what could cause this problem!

Did the authors explain how their experiments were performed, how many runs they used to estimate the accuracy, and what the stddev of the final metric was?
If not, I would guess that a “good seed” might have been picked.

From their experiments it didn’t seem like they chose a specific seed.
They just mentioned that the accuracy was 0.971±0.001.
But I know that their implementation wasn’t stable at all: in some runs the program crashed because of the Cholesky factorization, and the solution was to start with a smaller learning rate.
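The Cholesky crash happens when xᵀx is not (numerically) positive definite, for example when output columns become collinear. A minimal sketch of one way to detect this instead of crashing, using `torch.linalg.cholesky_ex` (which returns an error code rather than raising) and retrying with a small diagonal jitter. The helper name, the jitter values and the retry loop are assumptions, not something from the paper or this thread, and a jittered factor makes the outputs only approximately orthonormal:

```python
import torch


def safe_cholesky(a: torch.Tensor, jitter: float = 1e-6, max_tries: int = 5) -> torch.Tensor:
    """Lower Cholesky factor of a, adding a multiple of I to the diagonal if needed."""
    eye = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
    for attempt in range(max_tries):
        L, info = torch.linalg.cholesky_ex(a)  # returns an error code instead of raising
        if info.item() == 0:  # 0 means the factorization succeeded
            return L
        a = a + jitter * (10 ** attempt) * eye  # add a larger jitter before each retry
    raise RuntimeError("matrix is not positive definite even after adding jitter")


# A singular 2x2 matrix (determinant 0), standing in for a degenerate x^T x.
a = torch.tensor([[4.0, 2.0], [2.0, 1.0]], dtype=torch.float64)
_, info = torch.linalg.cholesky_ex(a)
print(info.item() > 0)  # True: plain Cholesky fails on this matrix
L = safe_cholesky(a)
print(torch.allclose(L @ L.T, a + 1e-6 * torch.eye(2, dtype=torch.float64)))  # True
```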

I don’t see anything obviously wrong in your code.
One thing you could check is the validation loss calculation. In your current code snippet you accumulate validate_loss from the batch losses returned by valid_step and divide it by len(valid_loader), i.e. the number of batches in the DataLoader.
If the last batch is smaller, this adds a small bias to the validation loss, so you might want to multiply each batch loss by its batch size and divide by the number of samples.
Something like this should work:

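    for batch_x, batch_y in valid_loader:
        batch_x = batch_x.to(device=self.device)
        batch_x = batch_x.view(batch_x.size(0), -1)
        loss = self.valid_step(batch_x, batch_y)
        # weight each batch loss by the number of samples in the batch
        validate_loss += loss * batch_x.size(0)

    # divide by the number of samples, not the number of batches
    validate_loss = validate_loss / len(valid_loader.dataset)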
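A small worked example of that bias, with made-up per-sample losses and assuming each batch loss is the mean over its samples: 10 samples with a batch size of 4 give batches of 4, 4 and 2 samples, and averaging the batch means gives the short last batch the same weight as the full ones:

```python
# Hypothetical per-sample losses for a validation set of 10 samples.
losses = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 3.0, 3.0]
batch_size = 4
batches = [losses[i:i + batch_size] for i in range(0, len(losses), batch_size)]
batch_means = [sum(b) / len(b) for b in batches]  # one loss value per batch

# Divide by the number of batches: the short last batch is over-weighted
mean_of_means = sum(batch_means) / len(batches)

# Weight by batch size, divide by the number of samples: the true per-sample mean
weighted = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(losses)

print(mean_of_means)  # 1.6666666666666667
print(weighted)       # 1.4
```

With many large batches the difference shrinks, which is why it is only a small bias in practice.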

In any case, you might want to reach out to the authors to check how stable their training was. That said, reaching the target accuracy at all already sounds like a valid result.

First, thank you very much for your helpful responses!
I’ll check the validation loss.
So you’re saying that if I manage to achieve their results in general, that’s enough, even if it doesn’t happen for all seeds?

Take my response with a grain of salt, but note that nobody will be able to use “all of the seeds” for any kind of experiment. A rigorous study would use “a few” different seeds and report the mean+/-stddev. In the best case, the source code is also provided so that other researchers can reproduce the results and verify them. However, that’s not always the case in reality and sometimes even single results are reported.
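The protocol described above (a few seeds, reporting mean +/- stddev) can be sketched as follows. `summarize_over_seeds` and `fake_run` are hypothetical names; `fake_run` only stands in for the real training and evaluation, and its numbers are illustrative, not results:

```python
import statistics
from typing import Callable, Sequence, Tuple


def summarize_over_seeds(run: Callable[[int], float], seeds: Sequence[int]) -> Tuple[float, float]:
    """Run one experiment per seed; return (mean, sample standard deviation)."""
    scores = [run(seed) for seed in seeds]
    return statistics.mean(scores), statistics.stdev(scores)


def fake_run(seed: int) -> float:
    """Hypothetical stand-in for a full train + evaluate run; replace with the real pipeline."""
    return 0.97 if seed % 2 == 0 else 0.82  # illustrative numbers only


mean, std = summarize_over_seeds(fake_run, seeds=[0, 1, 2, 3, 4])
print(f"accuracy = {mean:.3f} +/- {std:.3f}")  # accuracy = 0.910 +/- 0.082
```

A bimodal outcome like the 0.97/0.82 split in this thread shows up as a large spread, which is very different from a reported ±0.001.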

You’ve also mentioned that “in some runs the program crashed because of the Cholesky factorization”, which already points to some instability issues.

Got it
Again, thanks a lot!!