Well, the error is in the title. For some reason there is some operation that breaks the gradient, and i have no idea where it is
Here is the code:
Model: (Audio decoders and encoders are conv1d encoders that compress input 4 times, and are not the source of the problem, they were tested separately and they worked)
class AudioGenerator(nn.Module):
def __init__(self,bs):
super(AudioGenerator,self).__init__()
self.encoder = AudioEncoder()
self.decoder = AudioDecoder()
self.bs = bs #batch size
self.lstm = nn.LSTM(4000,4000, batch_first = True)
self.linear1 = nn.Linear(4000,4000)
self.reset_hidden(bs)
self.tanh = nn.Tanh()
self.relu = nn.LeakyReLU()
def forward(self,inp):
out1 = self.encoder(inp)
_,(self.hidden,self.state) = self.lstm(out1, (self.hidden,self.state))
out = self.hidden.clone()
out = torch.squeeze(out,0)
out = self.linear1(out)
out = self.tanh(out)
out = out.reshape(out.shape[0], 1,out.shape[1])
out = self.decoder(out)
return torch.squeeze(out,1)
def reset_hidden(self, bs = None):
if bs is None:
self.hidden = torch.zeros(1,self.bs,4000).to(DEVICE)
self.state = torch.zeros(1,self.bs,4000).to(DEVICE)
else:
self.hidden = torch.zeros(1,bs,4000).to(DEVICE)
self.state = torch.zeros(1,bs,4000).to(DEVICE)
def print_inline(message): #a helper function
sys.stdout.write("\r" + message)
Training loop: (dataset returns a tensor of shape (seq_length) where seq_length is 16000*num_seconds, data loader assembles them into shape (batch_size,1,seq_length), after that they get separated into seconds of length 16000, and processed)
losses = [[]]
epochs = 100
batchnum = 1
batches = len(train_loader)
try:
for epoch in range(0,epochs):
for seqs in train_loader:
seqs = seqs.to(DEVICE)
#print("Seqs shape is: {}".format(seqs.shape))
start = time.time()
seqs = torch.split(seqs,16000,1) #extract seconds from input sequence
#Reset hidden and grad
model.reset_hidden()
optimizer.zero_grad()
for i in range(len(seqs)-1): #edited mistake here, was compairing network output ot its input, fixed that
output = model(torch.unsqueeze(seqs[i],1))
output_loss = loss(output,seqs[i+1])
output_loss.backward(retain_graph=True)
optimizer.step()
losses[0].append(output_loss.item())
print_inline("Epoch: {}/{} Batch: {}/{} Loss: {:.4f}; time/batch:{:.4f} seconds".format(epoch,epochs,batchnum,batches,
output_loss.item(), time.time() - start))
if batchnum % 25 == 0:
print('New checkpoint!')
torch.save({'optimizer_state_dict':generator.state_dict(),
'model_state_dict':discriminator.state_dict()
},save_dir)
plot_losses(losses)
batchnum += 1
batchnum = 0
plot_losses(losses)
except KeyboardInterrupt: #if interrupted show statistics
plot_losses(losses)
plot_losses(losses)