Why does a customized backward function have extra time consumption?

Hi everyone:

I have created a custom autograd Function with a customized forward and backward process.

When I recorded the time consumption of the forward and backward passes, I found that there is always extra time consumption in the backward pass.
Could anyone tell me where this extra time in the backward pass comes from?

import datetime

import torch
from scipy.linalg import solve as scipy_solve
from torch.autograd import Function
class SparseDenseSolve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")
        step_0 = datetime.datetime.now()
        # Densify A if needed; .detach().numpy() replaces the deprecated .data.numpy()
        if A.is_sparse:
            A_np = A.to_dense().detach().numpy()
        else:
            A_np = A.detach().numpy()
        step_1 = datetime.datetime.now()

        b_np = b.detach().numpy()
        step_2 = datetime.datetime.now()

        # Solve Ax = b on the CPU with SciPy
        x_np = scipy_solve(A_np, b_np)
        step_3 = datetime.datetime.now()

        # No requires_grad here: autograd sets grad tracking on a Function's output
        x = torch.from_numpy(x_np)
        step_4 = datetime.datetime.now()

        ctx.save_for_backward(A, b, x)
        step_5 = datetime.datetime.now()

        print('======forward process=======')
        print(' step_1: ', step_1 - step_0)
        print(' step_2: ', step_2 - step_1)
        print(' step_3: ', step_3 - step_2)
        print(' step_4: ', step_4 - step_3)
        print(' step_5: ', step_5 - step_4)

        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        step_0 = datetime.datetime.now()
        # For Ax = b: grad_b = A^{-T} @ grad. Re-entering the Function here
        # runs forward() again, including another to_dense conversion and
        # SciPy solve, so step_1 contains a full extra forward pass.
        gradb = SparseDenseSolve.apply(torch.transpose(A, 0, 1), grad)
        step_1 = datetime.datetime.now()

        # grad_A = -grad_b @ x^T
        gradA = -gradb @ torch.transpose(x, 0, 1)
        step_2 = datetime.datetime.now()
        print('======backward process=======')
        print(' step_1: ', step_1 - step_0)
        print(' step_2: ', step_2 - step_1)
        if A.is_sparse:
            gradA = gradA.to_sparse()
            step_3 = datetime.datetime.now()
            print(' step_3: ', step_3 - step_2)
        return gradA, gradb


solve = SparseDenseSolve.apply

width = 64*64
A = torch.randn(width, width, requires_grad=True)
A_sparse = A.to_sparse()
b = torch.randn(width, 1)

forward1 = datetime.datetime.now()
solution_cusfunc = solve(A_sparse, b)
forward2 = datetime.datetime.now()

backward1 = datetime.datetime.now()
# torch.randn already returns float32, so no .float() cast is needed
solution_cusfunc.backward(torch.randn(width, 1))
backward2 = datetime.datetime.now()

print('time consumption:')
print('forward process: ', forward2 - forward1)
print('backward process: ', backward2 - backward1)

print('done')
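
As a side note, before timing I also wanted to confirm that this backward is numerically correct. Here is a minimal gradcheck sketch of my own (not part of the timing question), which assumes the solve alias above and uses small, well-conditioned, double-precision inputs because gradcheck compares against finite differences:

from torch.autograd import gradcheck

# Keep the problem tiny: gradcheck evaluates the function many times, and
# every evaluation triggers the print statements in forward/backward.
A_check = (torch.randn(4, 4, dtype=torch.float64)
           + 4.0 * torch.eye(4, dtype=torch.float64)).requires_grad_()
b_check = torch.randn(4, 1, dtype=torch.float64, requires_grad=True)
print('gradcheck:', gradcheck(solve, (A_check, b_check), eps=1e-6, atol=1e-4))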

The printed result of the timing script above is the following:

/home/hongjin/anaconda3/envs/new-python3.6/bin/python /home/hongjin/PycharmProjects/Year_2021/Colorization_libo/0714_customed_sparse_linear_system_pytorch/sparsedense_input_matrix_matrix.py
======forward process=======
 step_1:  0:00:02.819386
 step_2:  0:00:00.000005
 step_3:  0:00:00.482890
 step_4:  0:00:00.000136
 step_5:  0:00:00.000005
======forward process=======
 step_1:  0:00:04.902399
 step_2:  0:00:00.000006
 step_3:  0:00:00.457058
 step_4:  0:00:00.000080
 step_5:  0:00:00.000004
======backward process=======
 step_1:  0:00:05.488562
 step_2:  0:00:00.010019
 step_3:  0:00:00.378122
time consumption:
forward process:  0:00:03.304161
backward process:  0:00:08.660980
done

Process finished with exit code 0
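
For what it's worth, a profiler run gives a per-operator view instead of end-to-end wall-clock deltas. Here is a minimal sketch of my own that reuses solve, A_sparse, b, and width from the script above; I believe torch.autograd.profiler attributes the ops executed during backward (including the nested apply call) to their own rows:

with torch.autograd.profiler.profile() as prof:
    solution = solve(A_sparse, b)
    solution.backward(torch.randn(width, 1))

# Aggregated per-op breakdown, sorted by total CPU time; note that the time
# spent inside the SciPy call is not split into individual torch ops.
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))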