How does one use Hermite polynomials in PyTorch with Stochastic Gradient Descent (SGD)?

I was trying to train a simple polynomial linear model in PyTorch using Hermite polynomials, since they seem to lead to a better-conditioned Hessian.

To do that I decided to use hermvander, since it gives the Vandermonde matrix with each entry being a Hermite term. I just made my feature vectors the output of hermvander:

Kern_train = hermvander(X_train,Degree_mdl)
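
Just for context, here is a quick sanity check (a minimal sketch, separate from the script below) of what hermvander returns: column i is the Hermite polynomial H_i evaluated at the sample points, and its condition number can be printed next to that of a plain monomial Vandermonde matrix on the same points:

import numpy as np
from numpy.polynomial.hermite import hermvander

x = np.linspace(0, 1, 10)                   # 10 sample points in [0, 1]
K_herm = hermvander(x, 9)                   # shape (10, 10), column i = H_i(x)
K_mono = np.vander(x, 10, increasing=True)  # plain monomial Vandermonde, for comparison
print(K_herm.shape)
print(np.linalg.cond(K_herm), np.linalg.cond(K_mono))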

However, when I proceed to train I get NaN all the time. I suspected it could be a step-size issue, but I used the step size suggested by this question, which already has my example working in R, so I figured there was no need to search for one. Still, when I tried it, it did not work.
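
In case it helps to see where things blow up, here is a small debug helper (hypothetical, not part of the script below) that can be called inside the training loop right after the backward pass, using the same old Variable-style API as the rest of the code:

import numpy as np

def grads_blew_up(mdl):
    # debug helper (hypothetical): True if any parameter gradient contains NaN or Inf
    for W in mdl.parameters():
        g = W.grad.data.numpy()
        if np.isnan(g).any() or np.isinf(g).any():
            return True
    return False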

Does any PyTorch expert have an idea what's going on?

Totally self-contained code:

import numpy as np
from numpy.polynomial.hermite import hermvander

import random

import torch
from torch.autograd import Variable

def vectors_dims_dont_match(Y,Y_):
    '''
    Checks that vectors Y and Y_ have the same dimensions. If they don't,
    there might be an error caused by wrong broadcasting.
    '''
    DY = tuple( Y.size() )
    DY_ = tuple( Y_.size() )
    if len(DY) != len(DY_):
        return True
    for i in range(len(DY)):
        if DY[i] != DY_[i]:
            return True
    return False

def index_batch(X,batch_indices,dtype):
    '''
    Returns the rows of X selected by batch_indices, as a tensor of the given dtype.
    '''
    if len(X.shape) == 1: # i.e. dimension (M,) just a vector
        batch_xs = torch.FloatTensor(X[batch_indices]).type(dtype)
    else:
        batch_xs = torch.FloatTensor(X[batch_indices,:]).type(dtype)
    return batch_xs

def get_batch2(X,Y,M,dtype):
    '''
    Gets a random minibatch of size M for the PyTorch model.
    '''
    # TODO fix and make it nicer, there is pytorch forum question
    X,Y = X.data.numpy(), Y.data.numpy()
    N = len(Y)
    valid_indices = np.array( range(N) )
    batch_indices = np.random.choice(valid_indices,size=M,replace=False)
    batch_xs = index_batch(X,batch_indices,dtype)
    batch_ys = index_batch(Y,batch_indices,dtype)
    return Variable(batch_xs, requires_grad=False), Variable(batch_ys, requires_grad=False)

def get_sequential_lifted_mdl(nb_monomials,D_out, bias=False):
    return torch.nn.Sequential(torch.nn.Linear(nb_monomials,D_out,bias=bias))

def train_SGD(mdl, M,eta,nb_iter,logging_freq ,dtype, X_train,Y_train):
    ##
    #pdb.set_trace()
    N_train,_ = tuple( X_train.size() )
    #print(N_train)
    for i in range(1,nb_iter+1):
        # Forward pass: compute predicted Y using operations on Variables
        batch_xs, batch_ys = get_batch2(X_train,Y_train,M,dtype) # [M, D], [M, 1]
        ## FORWARD PASS
        y_pred = mdl.forward(batch_xs)
        ## Check vectors have same dimension
        if vectors_dims_dont_match(batch_ys,y_pred):
            raise ValueError('Your vectors don\'t have matching dimensions. This will lead to errors.')
        ## LOSS + Regularization
        batch_loss = (1/M)*(y_pred - batch_ys).pow(2).sum()
        ## BACKWARD PASS
        batch_loss.backward() # Use autograd to compute the backward pass. Now w will have gradients
        ## SGD update
        for W in mdl.parameters():
            delta = eta(i)*W.grad.data
            W.data.copy_(W.data - delta)
        ## train stats
        if i % (nb_iter//10) == 0:
            #X_train_, Y_train_ = Variable(X_train), Variable(Y_train)
            X_train_, Y_train_ = X_train, Y_train
            current_train_loss = (1/N_train)*(mdl.forward(X_train_) - Y_train_).pow(2).sum().data.numpy()
            print('\n-------------')
            print(f'i = {i}, current_train_loss = {current_train_loss}\n')
            print(f'eta(i)*W.grad.data = {eta(i)*W.grad.data}')
            print(f'W.grad.data = {W.grad.data}')
        ## Manually zero the gradients after updating weights
        mdl.zero_grad()
    final_sgd_error = current_train_loss
    return final_sgd_error
##
D0=1
logging_freq = 100
#dtype = torch.cuda.FloatTensor
dtype = torch.FloatTensor
## SGD params
M = 5
eta_0 = 0.1
eta = lambda i: eta_0/(i**0.6) # decaying step size schedule, as in the linked question
nb_iter = 500*10
##
lb,ub = 0,1
freq_sin = 4 # 2.3
f_target = lambda x: np.sin(2*np.pi*freq_sin*x)
N_train = 10
X_train = np.linspace(lb,ub,N_train)
Y_train = f_target(X_train).reshape(N_train,1)
x_horizontal = np.linspace(lb,ub,1000).reshape(1000,1)
## degree of mdl
Degree_mdl = N_train-1
## Hermite
Kern_train = hermvander(X_train,Degree_mdl)
Kern_train = Kern_train.reshape(N_train,-1) # flatten to (N_train, Degree_mdl+1); works whether hermvander returned a 2-D or 3-D array
##
Kern_train_pinv = np.linalg.pinv( Kern_train )
c_pinv = np.dot(Kern_train_pinv, Y_train)
##
## condition number of the design matrix; the Hessian of the squared loss is proportional to Kern_train^T Kern_train, so its condition number is this value squared
condition_number_hessian = np.linalg.cond(Kern_train)
## linear mdl to train with SGD
nb_terms = c_pinv.shape[0]
mdl_sgd = get_sequential_lifted_mdl(nb_monomials=nb_terms,D_out=1, bias=False)
mdl_sgd[0].weight.data.normal_(mean=0,std=0.001)
mdl_sgd[0].weight.data.fill_(0) # note: this overrides the normal_ init on the previous line
## Make polynomial Kernel
Kern_train_pt, Y_train_pt = Variable(torch.FloatTensor(Kern_train).type(dtype), requires_grad=False), Variable(torch.FloatTensor(Y_train).type(dtype), requires_grad=False)
final_sgd_error = train_SGD(mdl_sgd, M,eta,nb_iter,logging_freq ,dtype, Kern_train_pt,Y_train_pt)
## PRINT ERRORS
#from plotting_utils import * # not needed for this self-contained example

train_error_pinv = (1/N_train)*(np.linalg.norm(Y_train-np.dot(Kern_train,c_pinv))**2)
print('\n-----------------')
print(f'N_train={N_train}')
print(f'train_error_pinv = {train_error_pinv}')
print(f'final_sgd_error = {final_sgd_error}')

print(f'condition_number_hessian = {condition_number_hessian}')
print('\a')
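
For completeness, the manual W.data.copy_(W.data - eta(i)*W.grad.data) update above should be mathematically equivalent to torch.optim.SGD (no momentum) with the learning rate set to eta(i) each iteration. A minimal sketch of that variant, assuming the same nb_terms, nb_iter, Kern_train_pt and Y_train_pt from the script above:

import torch

# sketch only: same linear model, trained with torch.optim.SGD instead of the manual update
mdl2 = torch.nn.Sequential(torch.nn.Linear(nb_terms, 1, bias=False))
mdl2[0].weight.data.fill_(0)
optimizer = torch.optim.SGD(mdl2.parameters(), lr=0.1)
for it in range(1, nb_iter + 1):
    optimizer.zero_grad()
    # full-batch gradient here for brevity; the script above samples minibatches of size M
    loss = (mdl2(Kern_train_pt) - Y_train_pt).pow(2).mean()
    loss.backward()
    for g in optimizer.param_groups:  # mirror the decaying step size eta(i) = 0.1 / i**0.6
        g['lr'] = 0.1 / it**0.6
    optimizer.step()
print((mdl2(Kern_train_pt) - Y_train_pt).pow(2).mean().data.numpy())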