Windows fatal exception: stack overflow while using pytorch for computing

Here is my code
My pytorch is 2.0.0+CPU

# -*- coding: utf-8 -*-
"""

"""


import torch
from torch.func import jacfwd
import numpy as np
import matplotlib.pyplot as plt
from timeit import default_timer as timer
#from functorch import jacfwd
#from functorch import jacrev
#from torch.autograd.functional import jacobian
from scipy.sparse.linalg import gmres


def get_capillary(swnew):
    
    capillary = torch.div(1, swnew)
    capillary[capillary == float("Inf")] = 200
    capillary = 2*swnew
    capillary = swnew
    
    return capillary


def get_relaperm (slnew):
    
    k0r1=0.6
    L1=1.8
    L2=1.8
    E1=2.1
    E2=2.1
    T1=2.3
    T2=2.3
    
    krl=torch.div(k0r1*slnew**L1,slnew**L2+E1*(1-slnew)**T1)
    krg=torch.div((1-slnew)**L2,(1-slnew)**L2+E2*slnew**T2)
    
    return krl, krg


def get_relaperm_classic (swnew):
    
    s_w_r = 0.2
    s_o_r = 0.2
    
    nw = 2
    no = 2
    
    krw_ = 1
    kro_ = 1
    
    krw = krw_*((swnew-s_w_r)/(1-s_w_r))**nw
    kro = kro_*(((1-swnew)-s_o_r)/(1-s_o_r))**no
        
    return krw, kro

def get_residual (unknown):      
    
    residual = torch.zeros(Np*Nx*Ny*Nz, requires_grad=False, dtype=torch.float64)
    
    pre_o = unknown[::2]
    pre_w = pre_o
    sat_w = unknown[1::2]
    sat_o = 1 - sat_w 
    
    pre_o_old = unknownold[::2]
    pre_w_old = pre_o_old
    
    sat_w_old = unknownold[1::2]
    sat_o_old = 1 - sat_w_old 
    
    
    poro      = poroini*(1+c_r*(pre_o-p_ref)+0.5*(c_r*(pre_o-p_ref))**2)
    poroold   = poroini*(1+c_r*(pre_o_old-p_ref)+0.5*(c_r*(pre_o_old-p_ref))**2)
    
    Bo        = Bo_ref/((1+c_o*(pre_o-p_ref)+0.5*(c_o*(pre_o-p_ref))**2))
    Boold     = Bo_ref/((1+c_o*(pre_o_old-p_ref)+0.5*(c_o*(pre_o_old-p_ref))**2))

    
    Bw        = Bw_ref/((1+c_w*(pre_w-p_ref)+0.5*(c_w*(pre_w-p_ref))**2))
    Bwold     = Bw_ref/((1+c_w*(pre_w_old-p_ref)+0.5*(c_w*(pre_w_old-p_ref))**2))

    
    miu_o     = miu_o_ref*(((1+c_o*(pre_o-p_ref)+0.5*(c_o*(pre_o-p_ref))**2))/((1+(c_o-upsilon_o)*(pre_o-p_ref)+0.5*((c_o-upsilon_o)*(pre_o-p_ref))**2)))
    miu_w     = miu_w_ref*(((1+c_w*(pre_w-p_ref)+0.5*(c_w*(pre_w-p_ref))**2))/((1+(c_w-upsilon_w)*(pre_w-p_ref)+0.5*((c_w-upsilon_w)*(pre_w-p_ref))**2)))

    
    
    Accumulation_o = (1/C1)*(sat_o*poro/Bo - sat_o_old*poroold/Boold)*vol
    Accumulation_w = (1/C1)*(sat_w*poro/Bw - sat_w_old*poroold/Bwold)*vol
        
    residual[::2]   += Accumulation_o
    residual[1::2]  += Accumulation_w 
     
    
    kro       = get_relaperm (sat_w)[1]
    krw       = get_relaperm (sat_w)[0]

    
    
    mobi_o = torch.div(kro, (Bo*miu_o))
    mobi_w = torch.div(krw, (Bw*miu_w))
    
    
    oterm = mobi_o/(mobi_o+mobi_w)*qp*dt
    wterm = mobi_w/(mobi_o+mobi_w)*qp*dt+qi*dt
    
    
    residual[::2]   += oterm
    residual[1::2]  += wterm 
    



    '''
    capillary = get_capillary(sat_l)
    
    pre_l     = pre_g - capillary
    pre_l     = pre_g
    
    
    gravity_g = rho_g*g
    gravity_l = rho_l*g
    '''

    
    
    for i in connection_index:
        phi_pre_o = pre_o[connection_a[i]] - pre_o[connection_b[i]]
        phi_pre_w = pre_w[connection_a[i]] - pre_w[connection_b[i]]
        
        up_o = connection_a[i] if phi_pre_o >= 0 else connection_b[i]
        up_w = connection_a[i] if phi_pre_w >= 0 else connection_b[i]

        K_h       = 2*K[connection_a[i]]*K[connection_b[i]] / (K[connection_a[i]] + K[connection_b[i]])
        Tran_h    = K_h*A[i]/d[i]
     
        
        Tran_o    = Tran_h*kro[up_o]/miu_o[up_o]/Bo[up_o]        
        Tran_w    = Tran_h*krw[up_w]/miu_w[up_w]/Bw[up_w] 
        
        
        flux_o    = Tran_o*phi_pre_o
        flux_w    = Tran_w*phi_pre_w
        
        
        ind_a     = 2*connection_a[i]
        ind_b     = 2*connection_b[i]
        
        
        residual[ind_a] += C2*dt*flux_o
        residual[ind_b] -= C2*dt*flux_o
        
        ind_a     += 1
        ind_b     += 1
        

        residual[ind_a] += C2*dt*flux_w
        residual[ind_b] -= C2*dt*flux_w
          

    
    return residual


if __name__ == '__main__':
    
    C1 = 5.615
    C2 = 1.12712e-3
    
    
    Nx = 25
    Ny = 25
    Nz = 1
    
    Lx = 500
    Ly = 100
    Lz = 100
    
    
    p_ref = 14.7
    
    dx = Lx/Nx
    dy = Ly/Ny
    dz = Lz/Nz
    
    vol = dx*dy*dz;
     
    
    K   = 100*torch.ones(Nx*Ny*Nz, requires_grad=True, dtype=torch.float64)
   
    
    dt   = 0.1
    tf   = 1
    time = 0
    alpha_chop = 0.5
    alpha_grow = 2
    dt_min = 0.1
    Max_iter = 10
    Tol_resi = 1e-7

    
    
    g = 9.80665 
    
    Np= 2



    c_r = 1e-6
    c_o = 1e-4 #oil 
    c_w = 1e-6 #water
    p_ref  = 14.7
    Bo_ref = 1
    Bw_ref = 1
    miu_o_ref = 1
    miu_w_ref = 1
    rho_o_ref = 53
    rho_w_ref = 64
    poroini   = 0.2
    
    upsilon_o = 0
    upsilon_w = 0
    
    
    
    #Well
    qi = torch.zeros(Nx, Ny)
    qp = torch.zeros(Nx, Ny)
    
    
    #wilocax = torch.randint(0, Nx, (1,))
    #wilocay = torch.randint(0, Ny, (1,))

    wilocax = 0*torch.ones(1, dtype=torch.int64)
    wilocay = 0*torch.ones(1, dtype=torch.int64)

    
    qi[wilocax, wilocay] = -100
    
    
    #wplocax = torch.randint(0, Nx, (1,))
    #wplocay = torch.randint(0, Ny, (1,))

    wplocax = 4*torch.ones(1,dtype=torch.int64)
    wplocay = 0*torch.ones(1,dtype=torch.int64)

    
    qp[wplocax, wplocay] = 10
    
    qi=qi.reshape([-1,])
    qp=qp.reshape([-1,])
    
    


    #grids        = torch.arange(0, Nx*Ny*Nz, requires_grad=False, dtype=torch.int32)
    #grids        = torch.reshape(grids,(Nx,Ny)).t()
        

    connection_x = torch.arange(0, (Nx-1)*Ny*Nz, requires_grad=False, dtype=torch.int32)
    connection_y = torch.arange(0, Nx*(Ny-1)*Nz, requires_grad=False, dtype=torch.int32)
    connection_z = torch.arange(0, Nx*Ny*(Nz-1), requires_grad=False, dtype=torch.int32) 
    connection   = torch.cat((connection_x, connection_y, connection_z), 0)

    
    A_x  = dy*dz*torch.ones(connection_x.size(dim=0), requires_grad=False, dtype=torch.int32)
    A_y  = dx*dz*torch.ones(connection_y.size(dim=0), requires_grad=False, dtype=torch.int32)
    A_z  = dx*dy*torch.ones(connection_z.size(dim=0), requires_grad=False, dtype=torch.int32)
    A    = torch.cat((A_x, A_y, A_z), 0)
     
   
    d_x  = dx*torch.ones(connection_x.size(dim=0), requires_grad=False, dtype=torch.int32)
    d_y  = dy*torch.ones(connection_y.size(dim=0), requires_grad=False, dtype=torch.int32)
    d_z  = dz*torch.ones(connection_z.size(dim=0), requires_grad=False, dtype=torch.int32)
    d    = torch.cat((d_x, d_y, d_z), 0)
    

    connection_x_index = torch.arange(0, (Nx-1)*Ny*Nz, requires_grad=False, dtype=torch.int32)
    connection_y_index = torch.arange((Nx-1)*Ny*Nz, (Nx-1)*Ny*Nz+Nx*(Ny-1)*Nz, requires_grad=False, dtype=torch.int32)
    connection_z_index = torch.arange((Nx-1)*Ny*Nz+Nx*(Ny-1)*Nz, (Nx-1)*Ny*Nz+Nx*(Ny-1)*Nz+Nx*Ny*(Nz-1), requires_grad=False, dtype=torch.int32)
    connection_index  = torch.cat((connection_x_index, connection_y_index, connection_z_index), 0)

   
    connection_xa = connection_x + torch.div(connection_x, Nx-1, rounding_mode='trunc')
    connection_xb = connection_xa + 1
    connection_ya = connection_y + Nx*torch.div(connection_y, Nx*(Ny-1), rounding_mode='trunc')
    connection_yb = connection_ya + Nx
    connection_za = connection_z
    connection_zb = connection_za + Nx*Ny

    
    connection_a  = torch.cat((connection_xa, connection_ya, connection_za), 0)
    connection_b  = torch.cat((connection_xb, connection_yb, connection_zb), 0)
   
    
    # IC
    swnew       = 0.3*torch.ones(Nx*Ny*Nz, requires_grad=True, dtype=torch.float64)   
    ponew       = 6000*torch.ones(Nx*Ny*Nz, requires_grad=True, dtype=torch.float64)
    pwnew       = 6000*torch.ones(Nx*Ny*Nz, requires_grad=True, dtype=torch.float64)

    pc          = get_capillary(swnew)
    unknown     = torch.ravel(torch.column_stack((ponew, swnew)))
    unknownold  = unknown.detach().clone() 
    
    ntimestep  = 1;
    
    print('PyTorch version is '+torch.__version__)
    if torch.cuda.is_available():
        print("PyTorch is installed with GPU support.")
        print("Number of available GPUs:", torch.cuda.device_count())
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("PyTorch is installed without GPU support.")
    
    
    while abs(time - tf) > 1e-8:
        niter  = 0
        start  = timer()
        r      = get_residual(unknown)
        end    = timer()
        print('get_residual timing:', end - start)
        start  = timer()
        J      = jacfwd(get_residual)(unknown)
        #J      = jacobian(get_residual, unknown)
        end    = timer()
        print('Jacfwd timing: ', end - start)
        #cr     = r.detach().numpy()
        #cJ     = J.detach().numpy()
        #x, exitCode = gmres(Jnew, -rnew)

        while True:
            update       = torch.linalg.solve(J, r)
            niter        = niter+1;
            unknown     -= update
            #XiaoYuLing   = torch.where(unknown[1::2] < 0)
            #unknown[1::2][XiaoYuLing] = 0
            r            = get_residual(unknown)

            
            if (torch.linalg.vector_norm(r, 2) <= Tol_resi):
                is_coverged = 1                
                print (' ')
                print ('****************************************************************************************')
                print ('From time '+str(time)+' to time '+str(time+dt))
                print ('Timestep '+str(ntimestep)+' convergers, here is the report:')
                print ('2-Norm of the residual system: '+ str(torch.linalg.vector_norm(r, 2).detach().numpy()))
                print ('Number of iterations: '+ str(niter))
                print ('****************************************************************************************')
                print (' ')
                ntimestep   = ntimestep + 1 
                unknownold  = unknown.detach().clone() 
                
                plt.plot(unknown[::2].detach().numpy())
                plt.ylabel('Pressure')
                #plt.show()                
            else:
                is_coverged = 0
                J      = jacfwd(get_residual)(unknown)
            
            if ((niter > Max_iter) or (is_coverged)):
                break

        if (not is_coverged):
            #dt    *= alpha_chop 
            dt     = dt if dt >= dt_min else dt_min
        else:
            time  += dt
            #dt    *= alpha_grow
            dt     = (tf - time) if (time + dt) > tf else dt

If I increase the Nx, Ny, Nz here is the error
Windows fatal exception: stack overflow

Main thread:
Current thread 0x00002934 (most recent call first):
File “d:\dlsim\ai\25x25\main_ow.py”, line 328 in
File “C:\Users\wec8371\Anaconda3\envs\AISim\lib\site-packages\spyder_kernels\py3compat.py”, line 356 in compat_exec
File “C:\Users\wec8371\Anaconda3\envs\AISim\lib\site-packages\spyder_kernels\customize\spydercustomize.py”, line 469 in exec_code
File “C:\Users\wec8371\Anaconda3\envs\AISim\lib\site-packages\spyder_kernels\customize\spydercustomize.py”, line 611 in _exec_file
File “C:\Users\wec8371\Anaconda3\envs\AISim\lib\site-packages\spyder_kernels\customize\spydercustomize.py”, line 524 in runfile
File “C:\Users\wec8371\AppData\Local\Temp\ipykernel_8860\3764171792.py”, line 1 in

Restarting kernel…

If Nx, Ny are small, like 10, the code works well