Hi everyone:
I have created an custom function for an customized forward and backward process
When I recorded the time consumption of forward and backward process, I found there is alwasy an extra time consumption in backward process.
Is there anyone could tell me why there is an extra time consumption in backward process?
from scipy.linalg import solve as scipy_solve
from torch.autograd import gradcheck
from torch.autograd import Function
from scipy.sparse.linalg import spsolve
import torch
import datetime
from torch.nn import MSELoss
from torch.nn import L1Loss
class SparseDenseSolve(Function):
@staticmethod
def forward(ctx, A, b):
if A.ndim != 2 or (A.shape[0] != A.shape[1]):
raise ValueError("A should be a square 2D matrix.")
step_0 = datetime.datetime.now()
if A.is_sparse:
A_np = A.to_dense().data.numpy()
else:
A_np = A.data.numpy()
step_1 = datetime.datetime.now()
b_np = b.data.numpy()
step_2 = datetime.datetime.now()
x_np = scipy_solve(A_np, b_np)
step_3 = datetime.datetime.now()
x = torch.tensor(x_np, requires_grad=True)
step_4 = datetime.datetime.now()
ctx.save_for_backward(A, b, x)
step_5 = datetime.datetime.now()
print('======forward process=======')
print(' step_1: ', step_1 - step_0)
print(' step_2: ', step_2 - step_1)
print(' step_3: ', step_3 - step_2)
print(' step_4: ', step_4 - step_3)
print(' step_5: ', step_5 - step_4)
return x
@staticmethod
def backward(ctx, grad):
A, b, x = ctx.saved_tensors
step_0 = datetime.datetime.now()
gradb = SparseDenseSolve.apply(torch.transpose(A, 0, 1), grad)
step_1 = datetime.datetime.now()
gradA = -gradb @ torch.transpose(x, 0, 1)
step_2 = datetime.datetime.now()
print('======backward process=======')
print(' step_1: ', step_1 - step_0)
print(' step_2: ', step_2 - step_1)
if A.is_sparse:
gradA = gradA.to_sparse()
step_3 = datetime.datetime.now()
print(' step_3: ', step_3 - step_2)
return gradA, gradb
else:
return gradA, gradb
solve = SparseDenseSolve.apply
width = 64*64
A = torch.randn(width, width, requires_grad=True)
A_sparse = A.to_sparse()
b = torch.randn(width, 1)
forward1 = datetime.datetime.now()
solution_cusfunc = solve(A_sparse, b)
forward2 = datetime.datetime.now()
backward1 = datetime.datetime.now()
solution_cusfunc.backward(torch.randn(width, 1).float())
backward2 = datetime.datetime.now()
print('time consumption:')
print('forward process: ',forward2 - forward1)
print('backward process: ',backward2 - backward1)
print('done')
the printed result is following:
/home/hongjin/anaconda3/envs/new-python3.6/bin/python /home/hongjin/PycharmProjects/Year_2021/Colorization_libo/0714_customed_sparse_linear_system_pytorch/sparsedense_input_matrix_matrix.py
======forward process=======
step_1: 0:00:02.819386
step_2: 0:00:00.000005
step_3: 0:00:00.482890
step_4: 0:00:00.000136
step_5: 0:00:00.000005
======forward process=======
step_1: 0:00:04.902399
step_2: 0:00:00.000006
step_3: 0:00:00.457058
step_4: 0:00:00.000080
step_5: 0:00:00.000004
======backward process=======
step_1: 0:00:05.488562
step_2: 0:00:00.010019
step_3: 0:00:00.378122
time consumption:
forward process: 0:00:03.304161
backward process: 0:00:08.660980
done
Process finished with exit code 0