I am trying to compute second-order derivatives using the following two methods inside a custom class:
import torch
from torch.autograd import grad

def grad_log_target(self, log_target_val):
    # backward with create_graph=True so the gradient itself is differentiable
    log_target_val.backward(create_graph=True)
    return torch.cat([p.grad.view(-1) for p in self.parameters()])

def hess_log_target(self, grad_log_target_val):
    n_params = self.num_params()
    hess_log_target_val = []
    for i in range(n_params):
        # differentiate the i-th gradient component w.r.t. all parameters
        deriv_i_wrt_grad = grad(grad_log_target_val[i], self.parameters(), retain_graph=True)
        hess_log_target_val.append(torch.cat([h.view(-1) for h in deriv_i_wrt_grad]))
    hess_log_target_val = torch.cat(hess_log_target_val, 0).reshape(n_params, n_params)
    return hess_log_target_val
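For context, the same pattern can be reproduced outside my class on a toy model. This sketch uses a made-up quadratic log-target (the names `model` and `log_target` are illustrative, not from my actual code) whose Hessian is known in closed form:

```python
import torch
import torch.nn as nn
from torch.autograd import grad

# Toy model with 2 parameters (a single weight matrix of shape (1, 2)).
model = nn.Linear(2, 1, bias=False)

def log_target(m):
    # -0.5 * ||theta||^2, so the gradient is -theta and the Hessian is -I.
    return -0.5 * sum((p ** 2).sum() for p in m.parameters())

def grad_log_target(log_target_val):
    log_target_val.backward(create_graph=True)
    return torch.cat([p.grad.view(-1) for p in model.parameters()])

def hess_log_target(grad_val):
    n = grad_val.numel()
    rows = []
    for i in range(n):
        # Differentiate the i-th gradient component w.r.t. all parameters.
        d_i = grad(grad_val[i], list(model.parameters()), retain_graph=True)
        rows.append(torch.cat([h.view(-1) for h in d_i]))
    return torch.cat(rows, 0).reshape(n, n)

g = grad_log_target(log_target(model))
H = hess_log_target(g)
print(H)  # -I (the negative identity) for this toy target
```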
The code runs without error, but the computational cost grows with each iteration of my algorithm (a custom Monte Carlo sampler, not an optimizer). Here are the reported per-iteration times, which show the steady increase:
Iteration 1, duration 0:00:00.255764
Iteration 2, duration 0:00:00.245305
Iteration 3, duration 0:00:00.245335
Iteration 4, duration 0:00:00.254027
Iteration 5, duration 0:00:00.256714
Iteration 6, duration 0:00:00.252428
Iteration 7, duration 0:00:00.255005
Iteration 8, duration 0:00:00.257299
Iteration 9, duration 0:00:00.267525
Iteration 10, duration 0:00:00.268863
Iteration 11, duration 0:00:00.280390
Iteration 12, duration 0:00:00.276839
Iteration 13, duration 0:00:00.282090
Iteration 14, duration 0:00:00.279811
Iteration 15, duration 0:00:00.279818
Iteration 16, duration 0:00:00.288001
Iteration 17, duration 0:00:00.285519
Iteration 18, duration 0:00:00.319292
Iteration 19, duration 0:00:00.303236
Iteration 20, duration 0:00:00.314423
Iteration 21, duration 0:00:00.290186
Iteration 22, duration 0:00:00.290931
Iteration 23, duration 0:00:00.296420
Iteration 24, duration 0:00:00.296681
Iteration 25, duration 0:00:00.303668
Iteration 26, duration 0:00:00.312694
Iteration 27, duration 0:00:00.311441
Iteration 28, duration 0:00:00.310313
Iteration 29, duration 0:00:00.314935
Iteration 30, duration 0:00:00.314377
Iteration 31, duration 0:00:00.317583
Iteration 32, duration 0:00:00.335503
Iteration 33, duration 0:00:00.322630
Iteration 34, duration 0:00:00.319622
Iteration 35, duration 0:00:00.343187
Iteration 36, duration 0:00:00.326392
Iteration 37, duration 0:00:00.340833
Iteration 38, duration 0:00:00.330154
Iteration 39, duration 0:00:00.336486
Iteration 40, duration 0:00:00.360137
Iteration 41, duration 0:00:00.347079
Iteration 42, duration 0:00:00.339313
Iteration 43, duration 0:00:00.342601
Iteration 44, duration 0:00:00.352304
Iteration 45, duration 0:00:00.349789
Iteration 46, duration 0:00:00.353537
Iteration 47, duration 0:00:00.367225
Iteration 48, duration 0:00:00.377087
Iteration 49, duration 0:00:00.362028
Iteration 50, duration 0:00:00.372563
Iteration 51, duration 0:00:00.363326
Iteration 52, duration 0:00:00.374069
Iteration 53, duration 0:00:00.371866
Iteration 54, duration 0:00:00.384171
Iteration 55, duration 0:00:00.380464
Iteration 56, duration 0:00:00.384370
Iteration 57, duration 0:00:00.384062
Iteration 58, duration 0:00:00.384333
Iteration 59, duration 0:00:00.392266
Iteration 60, duration 0:00:00.389869
Iteration 61, duration 0:00:00.398985
Iteration 62, duration 0:00:00.398490
Iteration 63, duration 0:00:00.414814
Iteration 64, duration 0:00:00.400672
Iteration 65, duration 0:00:00.396808
Iteration 66, duration 0:00:00.393161
Iteration 67, duration 0:00:00.409639
Iteration 68, duration 0:00:00.433340
Iteration 69, duration 0:00:00.414874
Iteration 70, duration 0:00:00.428305
Iteration 71, duration 0:00:00.446901
Iteration 72, duration 0:00:00.410948
Iteration 73, duration 0:00:00.419323
Iteration 74, duration 0:00:00.422695
Iteration 75, duration 0:00:00.434619
Iteration 76, duration 0:00:00.435668
Iteration 77, duration 0:00:00.452214
Iteration 78, duration 0:00:00.425546
Iteration 79, duration 0:00:00.443100
Iteration 80, duration 0:00:00.463992
Iteration 81, duration 0:00:00.439153
Iteration 82, duration 0:00:00.445664
Iteration 83, duration 0:00:00.447337
Iteration 84, duration 0:00:00.458129
Iteration 85, duration 0:00:00.452122
Iteration 86, duration 0:00:00.471280
Iteration 87, duration 0:00:00.451846
Iteration 88, duration 0:00:00.456373
Iteration 89, duration 0:00:00.458176
Iteration 90, duration 0:00:00.470319
Iteration 91, duration 0:00:00.463827
Iteration 92, duration 0:00:00.467097
Iteration 93, duration 0:00:00.480898
Iteration 94, duration 0:00:00.457989
Iteration 95, duration 0:00:00.497614
Iteration 96, duration 0:00:00.518428
Iteration 97, duration 0:00:00.476043
Iteration 98, duration 0:00:00.469286
Iteration 99, duration 0:00:00.499041
Iteration 100, duration 0:00:00.502108
Is this caused by a memory leak, i.e. the computational graph is never freed between iterations, so tensors (and the graph attached to them) are carried over from one iteration to the next?
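To make the suspicion concrete, here is a minimal standalone sketch (a bare tensor, not my class) of the mechanism I suspect: `backward(create_graph=True)` accumulates into `.grad`, and since `.grad` stays attached to the autograd graph, the history reachable from it grows across iterations unless `.grad` is reset:

```python
import torch

x = torch.ones(1, requires_grad=True)

for it in range(3):
    y = (x ** 2).sum()
    y.backward(create_graph=True)
    # .grad accumulates (2.0, then 4.0, then 6.0), and grad.grad_fn is
    # non-None, i.e. the gradient tensor still carries graph history.
    print(it, x.grad.item(), x.grad.grad_fn is not None)

# Resetting .grad each iteration drops the accumulated value and the
# old graph hanging off it:
x.grad = None
y = (x ** 2).sum()
y.backward(create_graph=True)
print(x.grad.item())  # 2.0 again
```

If this is indeed what is happening in my class, the per-iteration graph would keep growing, which would match the increasing runtimes above.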