How to calculate partial columns hessian

I am using following code to generate hessian with 10 columns only, not full hessian (dimension is d x m) and a m x m matrix M which should be symmetric.

import torch
import warnings
import torchvision
import torchvision.models as models
from torchvision import transforms
import torch.nn as nn
import numpy as np
device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)

transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) ), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.MNIST(root=‘data’, train=True, download=True, transform=transform), batch_size=512, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root=‘data’, train=False, download=True, transform=transform)
testloader =, batch_size=128, shuffle=True, num_workers=2)
cnn = models.resnet18(pretrained=False)
classifier = nn.Linear(in_features=512, out_features=10, bias=True)

loss_fn = nn.CrossEntropyLoss()
p=[ for gi in cnn.parameters()])
column_indx= np.random.choice(len(p), 20, replace= False)

for batch, (X, y) in enumerate(train_loader):
wt1=[ for gi in cnn.parameters()]).view(-1, 1)
X =
y =
loss=loss_fn(pred, y) #mean loss for one batch
l2_norm = sum(p.pow(2.0).sum() for p in cnn.parameters())
loss = loss + l2_lambda * l2_norm
env_grads = torch.autograd.grad(loss, cnn.parameters(), retain_graph=True, create_graph=True)
g=[gi.reshape(-1) for gi in env_grads])
C=torch.zeros(10, len(p)).to(device)
for i in range(10):
h_col=torch.autograd.grad(g[column_indx[i]], cnn.parameters(), retain_graph=True, create_graph=False)
C[i]=[gi.reshape(-1) for gi in h_col])

with torch.no_grad():
    M=C[: , column_indx]
C is the partial columns hessian (size d x m , where m is number of required columns) and M , which is formed by the intersection between those m columns of H.

If i am not wrong, hessian is always symmetric. Then M also should be symmetric. But i am not getting symmetric.
Please help me, how to get symmetric M.