When running this code I get a NaN loss, which makes no sense.

I ran identical code with the stock AdamW and it worked fine.

What's even weirder is that even the first loss is NaN.

```
import torch


class Adam():

    def __init__(self, parameters, lr=0.01, betas=(0.9, 0.999), eps=1e-8, decay=0.):
        # one state dict per parameter tensor: first/second moment buffers
        # (plus a max of the second moment, currently unused)
        self.groups = [{'params': p,
                        'mom1': torch.zeros(p.shape).to(p.device),
                        'mom2': torch.ones(p.shape).to(p.device),
                        'max_mom2': torch.ones(p.shape).to(p.device)} for p in parameters]
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.decay = decay
        self.step = 0

    def update(self):
        self.step += 1
        with torch.no_grad():
            for p in self.groups:
                grad = p['params'].grad
                # update the running moment estimates
                p['mom1'] = grad * (1 - self.betas[0]) + p['mom1'] * self.betas[0]
                p['mom2'] = grad * (1 - self.betas[1]) + p['mom2'] * self.betas[1]
                # p['max_mom2'] = torch.maximum(p['mom2'], p['max_mom2'])
                # bias correction
                m = p['mom1'] / (1 - self.betas[0] ** self.step)
                v = p['mom2'] / (1 - self.betas[1] ** self.step)
                # v = p['max_mom2'] / (1 - self.betas[1] ** self.step)
                v = torch.sqrt(v)
                # parameter update plus a decoupled weight-decay term
                p['params'].data -= self.lr * m / (v + self.eps)
                p['params'].data -= self.decay * p['params']
```
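
For reference, the update this is meant to implement is the usual bias-corrected Adam step, with the decay term applied straight to the weights:

$$
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
$$

$$
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad
\theta \leftarrow \theta - \frac{\mathrm{lr}\cdot \hat m_t}{\sqrt{\hat v_t}+\epsilon} - \mathrm{decay}\cdot\theta
$$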

`opt = Adam(model2.parameters())`
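
In case it matters, the optimizer is driven by an ordinary loop along these lines (just a sketch: `model2` here is a stand-in linear model and the data is random, not my actual setup):

```
import torch
import torch.nn as nn

model2 = nn.Linear(10, 1)   # placeholder for the real model
opt = Adam(model2.parameters())

x = torch.randn(32, 10)     # dummy batch
y = torch.randn(32, 1)

for epoch in range(5):
    # clear old gradients by hand, since this Adam class has no zero_grad()
    for group in opt.groups:
        group['params'].grad = None

    loss = nn.functional.mse_loss(model2(x), y)
    loss.backward()
    opt.update()
    print(epoch, loss.item())
```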