How does one obtain gradients as data efficiently?
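
For context, the two ways I currently know of to pull a gradient out as plain data (the tensors and values here are only for illustration, not part of the script below):

import torch

w = torch.tensor([2.0], requires_grad=True)
loss = (w*3.0 - 4.0)**2

# option 1: backward() populates w.grad; .item() (or .detach().clone()) copies the
# value out, so a later zero_() or accumulation on w.grad does not affect the copy
loss.backward()
g_val = w.grad.item()              # plain Python float
g_data = w.grad.detach().clone()   # tensor copy, detached from the graph

# option 2: torch.autograd.grad returns the gradient directly without touching w.grad;
# pass create_graph=True if the gradient itself needs to stay differentiable
loss2 = (w*3.0 - 4.0)**2
(g_tensor,) = torch.autograd.grad(loss2, w, create_graph=True)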

The whole code, just for reference:

'''
This script is written to test:
 1) what happens if we differentiate through PyTorch's backward pass, and whether
    the result is consistent with taking a second derivative
'''
import torch
import torch.nn as nn

from grad_test_sympy import symbolic_test


def test_does_backward_backward_match_second_derivative(experiment_type, grad_computation_type='dl_dw_zero_grad'):
    '''
    Tests whether differentiating through PyTorch's backward pass is consistent
    with taking a second derivative.

    :param str experiment_type:
        - grad_constant: sanity check that the sympy code and the PyTorch code match.
            When g is treated as a constant, backward on J does not differentiate
            through g, so this just makes sure the two implementations agree.
        - grad_constant_but_use_backward_on_loss: calls backward twice (first on the
            loss, then on J, which depends on the loss). I expected the result to
            match the sympy code, but it doesn't.
        - grad_analytic_only_backward_on_J: computes g analytically as an autograd
            expression of w and only calls backward on J.
    '''
    ## variable declaration
    w = torch.tensor([2.0], requires_grad=True)

    x = torch.tensor([3.0], requires_grad=False)
    y = torch.tensor([4.0], requires_grad=False)

    x2 = torch.tensor([5.0], requires_grad=False)
    y2 = torch.tensor([6.0], requires_grad=False)

    if experiment_type == 'grad_constant':
        ## computes backward pass on J (i.e. dJ_dw) assuming g is constant
        # compute g
        if grad_computation_type == 'g_hard_coded':
            grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
            g = grad.item()
        elif grad_computation_type == 'dl_dw_zero_grad':
            # compute the backward pass dl_dw
            loss = (w*x-y)**2
            loss.backward()
            g = w.grad.item()
            # zero out the gradients on the original tensor
            w.grad.zero_() # note this operation does NOT make g zero.
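            # (g was extracted with .item(), so it is a plain Python float, not a
            # view of w.grad, and zeroing w.grad cannot change it)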
        else:
            ww = torch.randn(w.size(), requires_grad=True)
            ww.data = w.clone()
            loss = (ww*x-y)**2
            loss.backward()
            grad = ww.grad
            g = grad.item()
        assert (2*(w*x-y)*x).item() == g  # g is a plain float in every branch above
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    elif experiment_type == 'grad_constant_but_use_backward_on_loss':
        ## computes backward pass on J (i.e. dJ_dw), but g was already produced by a backward pass on the loss
        # compute g
        loss = (w*x-y)**2
        loss.backward()
        #g = w.grad.item() # dl_dw
        g = w.grad # dl_dw
        g.requires_grad = True
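        # note: g is the same tensor object as w.grad, and w.grad is not zeroed here,
        # so it still holds dl_dw from loss.backward(); J.backward() below then
        # accumulates on top of that stale value, which is likely part of why this
        # case does not match the sympy result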
        print(g)
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    elif experiment_type == 'grad_analytic_only_backward_on_J':
        ## computes backward pass on J (i.e. dJ_dw) with g built analytically inside the autograd graph
        # compute g
        grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
        g = grad
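        # here g is an autograd expression of w (not a detached number), so
        # J.backward() below also differentiates through g's dependence on w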
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    else:
        raise ValueError(f'This: experiment_type={experiment_type}, is not a valid test')
    ##
    print('---- test results ----')
    print(f'experiment_type = {experiment_type}')
    print('-- Pytorch results')
    print(f'g = {g}')
    #print(f'dw_new_dw = {dw_new_dw}')
    print(f'dJ_dw = {dJ_dw}')
    print('-- Sympy results')
    g, dw_new_dw, dJ_dw = symbolic_test(experiment_type)
    print(f'g_SYMPY = {g}')
    #print(f'dw_new_dw_SYMPY = {dw_new_dw}')
    print(f'dJ_dw_SYMPY = {dJ_dw}')

if __name__ == '__main__':
    test_does_backward_backward_match_second_derivative(experiment_type='grad_constant')
    #test_does_backward_backward_match_second_derivative(experiment_type='grad_constant_but_use_backward_on_loss')
    #test_does_backward_backward_match_second_derivative(experiment_type='grad_analytic_only_backward_on_J')
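
For comparison, my understanding is that the usual way to differentiate through a gradient is to keep the graph of the first gradient computation alive with create_graph=True, e.g. via torch.autograd.grad. The minimal sketch below reuses the same numbers as the script; whether it reproduces the sympy values depends on what symbolic_test computes, so treat it as a sketch rather than a verified fix:

import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0]); y = torch.tensor([4.0])
x2 = torch.tensor([5.0]); y2 = torch.tensor([6.0])

loss = (w*x - y)**2
# create_graph=True keeps the graph of this gradient computation, so g is itself a
# differentiable function of w instead of a detached number, and w.grad is not
# touched (no stale accumulation)
(g,) = torch.autograd.grad(loss, w, create_graph=True)

w_new = w - (g + w**2) * g
J = (w_new + x2 + y2)**2
J.backward()
print(w.grad)  # dJ/dw, including the dependence of g on w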