Getting correct predictions for only one class

Hi,
I am trying to build a classification model for MNIST data using Newton's method for optimization.
After training, my model predicts correctly only for class 0. The test data contains 980 samples of class 0, and my model's test accuracy is only 9.8%. When I checked pred = model(inputs), I found that it predicts correctly only for the class-0 samples in the test data.
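The 9.8% figure is itself a clue. The standard MNIST test set has 10,000 images, 980 of them zeros, so a model that outputs class 0 for every input scores exactly 980 / 10,000 = 9.8%. A minimal plain-Python check of that arithmetic, plus a helper for spotting collapsed predictions (the helper names are hypothetical):

```python
from collections import Counter

# Per-class counts of the standard MNIST test set (10,000 images in total).
MNIST_TEST_COUNTS = {0: 980, 1: 1135, 2: 1032, 3: 1010, 4: 982,
                     5: 892, 6: 958, 7: 1028, 8: 974, 9: 1009}

def constant_predictor_accuracy(counts, predicted_class):
    """Accuracy of a model that outputs `predicted_class` for every input."""
    return counts[predicted_class] / sum(counts.values())

def predicted_class_histogram(preds):
    """Count how often each class is predicted; a single key means the model has collapsed."""
    return Counter(preds)

print(constant_predictor_accuracy(MNIST_TEST_COUNTS, 0))  # 0.098
```

With a PyTorch model, `preds` could be obtained as `pred.argmax(dim=1).tolist()`; if the histogram contains only class 0, the model is predicting a single class for everything rather than partly learning.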

My approach:

  1. ResNet18 with randomly initialized weights is used.

  2. Gradients are calculated with:

         env_grads = torch.autograd.grad(loss, cnn.parameters(), retain_graph=True, create_graph=True)

  3. The Hessian is calculated with:

         h_col = torch.autograd.grad(env_grads, cnn.parameters(), retain_graph=True, create_graph=False)

  4. The training loop is set up with the following script:

         import torch

         for batch, (X, y) in enumerate(train_loader):
             with torch.set_grad_enabled(True):
                 cnn.zero_grad()
                 cnn.train()
                 # Snapshot of all current weights, flattened into one column vector
                 wt1 = torch.cat([gi.data.view(-1) for gi in cnn.parameters()]).view(-1, 1)
                 X = X.to(device)
                 y = y.to(device)
                 pred = cnn(X)
                 loss = loss_fn(pred, y)
                 # (the rest of the loop, including the Newton update, is not shown)
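For reference, the two `torch.autograd.grad` calls in steps 2 and 3 can be combined into a Hessian-vector product, which never materializes the Hessian. When the outputs passed to `torch.autograd.grad` are not scalars (and `env_grads` is a tuple of non-scalar tensors), `grad_outputs` must be supplied, otherwise autograd raises an error. Passing a one-hot vector returns a single Hessian column, which may be what `h_col` is meant to hold. A minimal sketch on a toy quadratic, assuming PyTorch is installed (the helper name `hessian_vector_product` is hypothetical):

```python
import torch

def hessian_vector_product(loss, params, vec):
    """Return H @ vec, where H is the Hessian of `loss` w.r.t. `params`.

    `params` is a list of tensors; `vec` is a list of tensors with matching shapes.
    """
    # First pass: create_graph=True makes the gradient itself differentiable.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Second pass: `grads` are not scalars, so grad_outputs is required.
    # Using `vec` as grad_outputs yields H @ vec without forming H.
    return torch.autograd.grad(grads, params, grad_outputs=vec)

# Toy check: f(w) = 0.5 * w^T A w has Hessian exactly A (A is symmetric).
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
w = torch.tensor([1.0, -1.0], requires_grad=True)
loss = 0.5 * w @ A @ w
e1 = torch.tensor([1.0, 0.0])  # one-hot vector selects the first column
(hv,) = hessian_vector_product(loss, [w], [e1])
print(hv)  # expected: tensor([3., 1.]), the first column of A
```

For an 11-million-parameter network, Hessian-vector products like this are typically combined with an iterative solver (for example conjugate gradient) rather than extracting columns one at a time.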

Please help me find possible reasons for this issue.

I ran the code with SGD in the same setting, and it performs well.

I checked the data distribution; there is no class imbalance.

Hi Mrinmay!

Something is fishy here. ResNet18 has about 11 million individual
parameters. The full Hessian will therefore require something like
500 terabytes of storage, which is not realistic.
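The 500-terabyte estimate follows from the Hessian being an n × n matrix. A quick back-of-the-envelope check, assuming 32-bit floats (4 bytes per entry):

```python
def dense_hessian_bytes(n_params, bytes_per_entry=4):
    """Memory for a dense n x n Hessian; bytes_per_entry=4 assumes float32."""
    return n_params ** 2 * bytes_per_entry

n = 11_000_000  # "about 11 million" parameters, as stated above
print(f"{dense_hessian_bytes(n) / 1e12:.0f} TB for the Hessian")  # 484 TB
print(f"{n * 4 / 1e6:.0f} MB for the gradient")                  # 44 MB
```

Storing only the upper triangle (the Hessian is symmetric) merely halves this, to roughly 242 TB, so any practical second-order method for this network has to work with an approximation or with Hessian-vector products.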

It’s not clear what you are actually doing.

As an aside, instead of focusing on your predictions and accuracy,
when you are using a novel optimization method, you should start
by performing a single optimization step and check whether your
objective function actually goes down. If your optimization algorithm
isn’t reducing your objective function, something isn’t working correctly,
and you should fix that first before looking at higher-level performance
metrics.
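As an illustration of that check, here is a minimal pure-Python sketch on a one-dimensional toy objective (an assumption chosen for clarity, not the MNIST loss; the helper names are hypothetical). It also shows one way a Newton-type step can fail: where the curvature is negative, the plain Newton step x - f'(x)/f''(x) heads toward a maximum and the objective goes up.

```python
def loss_decreases(loss_fn, x, step_fn):
    """Take one step from x; return (went_down, loss_before, loss_after)."""
    before = loss_fn(x)
    after = loss_fn(step_fn(x))
    return after < before, before, after

# Toy non-convex objective: f(x) = x^4 - 3x^2 (illustrative assumption).
def f(x):
    return x**4 - 3 * x**2

def grad_f(x):
    return 4 * x**3 - 6 * x

def hess_f(x):
    return 12 * x**2 - 6

def newton_step(x):
    # Plain Newton: jumps to a stationary point of the local quadratic model,
    # which is a maximum when the curvature is negative.
    return x - grad_f(x) / hess_f(x)

def abs_curvature_step(x):
    # Divides by |f''(x)| so the step always points downhill (one possible fix).
    return x - grad_f(x) / abs(hess_f(x))

x0 = 0.5  # here f''(x0) = -3, i.e. negative curvature
print(loss_decreases(f, x0, newton_step))         # first entry False: the loss went up
print(loss_decreases(f, x0, abs_curvature_step))  # first entry True: the loss went down
```

Flipping the sign of negative curvature is only one possible remedy; damping or a line search are common alternatives. For the real model the pattern is the same: evaluate the loss on a fixed batch, take one step, and evaluate it again on that same batch.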

Best.

K. Frank

Hi Frank,
I am not using the full Hessian; I approximated it.