Hey all,
I’m trying to reimplement in PyTorch a model that someone wrote in Keras. However, my implementation performs noticeably worse than the Keras one. Here is the code for both:
Keras:
BATCH_SIZE = 50
MAX_EPOCHS = 200
model = Sequential()
model.add(LSTM(256, input_shape=(64, 49), activation='tanh', return_sequences=True))
model.add(Dense(49))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Adam(0.01))
callbacks = []
early_stopping = EarlyStopping(monitor='loss', min_delta=0.01, patience=10)
callbacks.append(early_stopping)
model.fit(training_data, label_data, batch_size=50, epochs=200, callbacks=callbacks)
PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim

class my_model(nn.Module):
    def __init__(self, hidden_sz, note_range, seq_len):
        super(my_model, self).__init__()
        self.Encoder = nn.LSTM(note_range, hidden_sz)
        self.Decoder = nn.Sequential(nn.Linear(hidden_sz, note_range), nn.Softmax(dim=-1))
        self.seq_len = seq_len
        self.hidden_sz = hidden_sz
        self.note_range = note_range

    def train(self, training_data, label_data, lr_rate=1e-2, epochs=200, batch_sz=128):
        # Set useful constant.
        seq_len = self.seq_len
        # Set optimizer and loss function.
        optimizer = optim.Adam(self.parameters(), lr=lr_rate)
        for epoch in range(epochs):
            # Shuffle the data and labels with the same permutation each epoch.
            N = training_data.shape[0]
            perm = torch.randperm(N)
            training_data = training_data[perm]
            label_data = label_data[perm]
            for i in range(N // batch_sz):
                batch_train = training_data[i * batch_sz:(i + 1) * batch_sz]
                batch_label = label_data[i * batch_sz:(i + 1) * batch_sz]
                # Keep only the LSTM output sequence, dropping the (h, c) state.
                encd = self.Encoder(batch_train)[0]
                encd = self.Decoder(encd)
                loss = -torch.sum(batch_label * torch.log(encd)) / (batch_sz * seq_len)  # Cross entropy
                # Zero gradient, calculate gradient, then gradient step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

model = my_model(hidden_sz=256, note_range=49, seq_len=64)
model.train(training_data, label_data)
In both cases I use the same data, and I kept the notation similar to highlight the correspondence. The Keras model reaches a cross-entropy loss below 1 by the 100th epoch, but the PyTorch one never gets there. Any ideas?
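One difference that may matter, sketched here under assumptions rather than as a confirmed fix: Keras reads `input_shape=(64, 49)` as (batch, seq, feature), while `nn.LSTM` defaults to (seq, batch, feature) unless `batch_first=True` is passed. The PyTorch run above also uses a batch size of 128 instead of 50, and the two libraries initialize LSTM weights differently by default. The snippet below assumes the data is laid out as (N, 64, 49) with one-hot labels; `SeqModel`, `one_hot_cross_entropy`, and the random tensors are hypothetical stand-ins, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqModel(nn.Module):
    """Illustrative port of the Keras model; the name is hypothetical."""

    def __init__(self, hidden_sz=256, note_range=49):
        super().__init__()
        # batch_first=True makes the LSTM read (batch, seq, feature),
        # which is the layout Keras assumes for input_shape=(64, 49).
        self.encoder = nn.LSTM(note_range, hidden_sz, batch_first=True)
        self.decoder = nn.Linear(hidden_sz, note_range)

    def forward(self, x):
        out, _ = self.encoder(x)
        # Return log-probabilities; log_softmax is more numerically
        # stable than taking log() of a softmax output.
        return torch.log_softmax(self.decoder(out), dim=-1)


def one_hot_cross_entropy(log_probs, one_hot_labels):
    # Cross entropy per time step, averaged over batch and sequence.
    return -(one_hot_labels * log_probs).sum(dim=-1).mean()


torch.manual_seed(0)
x = torch.randn(50, 64, 49)  # random stand-in for one batch of 50 sequences
labels = F.one_hot(torch.randint(0, 49, (50, 64)), num_classes=49).float()

model = SeqModel()
log_probs = model(x)
loss = one_hot_cross_entropy(log_probs, labels)
```

With this layout the output has one distribution over the 49 classes per sequence and time step, so the loss averages over the same elements as the Keras `categorical_crossentropy` setup. Whether this closes the gap on the real data is untested.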
Also, some help with editing the code on the forums so it doesn’t look like a mess would be great!