# Why is this Adam implementation cursed?

When running this code I get NaN loss, which makes no sense.
I ran identical code with the stock AdamW and it worked fine.
What's even weirder is that even the very first loss is NaN.

```python
def __init__(self, parameters, lr=0.01, betas=(0.9, 0.999), eps=1e-8, decay=0.):
    self.groups = [{'params': p,
                    'mom1': torch.zeros(p.shape).to(p.device),
                    'mom2': torch.ones(p.shape).to(p.device),
                    'max_mom2': torch.ones(p.shape).to(p.device)}
                   for p in parameters]
    self.lr = lr
    self.betas = betas
    self.eps = eps
    self.decay = decay
    self.step = 0

def update(self):
    self.step += 1
    for p in self.groups:
        p['mom1'] = grad * (1 - self.betas[0]) + p['mom1'] * self.betas[0]
        # p['max_mom2'] = torch.maximum(p['mom2'], p['max_mom2'])

        m = p['mom1'] / (1 - self.betas[0] ** self.step)
        v = p['mom2'] / (1 - self.betas[1] ** self.step)
        # v = p['max_mom2'] / (1 - self.betas[1] ** self.step)
        v = torch.sqrt()

        p['params'].data -= self.lr * m / (v + self.eps)
        p['params'].data -= self.decay * p['params']
```

`v = torch.sqrt()` is incorrect: `torch.sqrt` is a function that takes the tensor as its argument, so calling it with no arguments won't work. You probably wanted `v = torch.sqrt(v)`, or equivalently the tensor method `v = v.sqrt()`.
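For what it's worth, a minimal sketch showing that the two spellings are interchangeable (the example tensor values here are just made up for illustration):

```python
import torch

# Stand-in for the bias-corrected second moment in the optimizer.
v = torch.tensor([4.0, 9.0, 16.0])

# Function form: torch.sqrt takes the tensor as an argument.
a = torch.sqrt(v)

# Method form: the same operation, called on the tensor itself.
b = v.sqrt()

print(torch.equal(a, b))  # True — both give tensor([2., 3., 4.])
```

Either form works; in the `update` loop above it would replace the broken line as `v = torch.sqrt(v)`.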