I am currently trying to train a 3-layer LSTM for a classification task. The input sequence has variable length,so I padded every sequence with zero to the longest one within the minibatch and the padded label is set to -1 which will be ignore in the loss calculation. When I train LSTM with batch_size=1, it works well, the cross entropy loss decreases and the training classification accuracy increases. The problem is when I set batch_size >1, e.g. batch_size=8, the loss decreases while the accuracy do not increase. Could anyone help me to figure out why ?
Some related code is as follows:
class Model(nn.Module):
def __init__(self, args):
super(Model, self).__init__()
self.args = args
self.n_d = args.feadim
self.n_cell=args.hidnum
self.depth = args.depth
self.drop = nn.Dropout(args.dropout)
self.n_V = args.statenum
if args.lstm:
self.rnn = nn.LSTM(self.n_d, self.n_cell,
self.depth,
dropout = args.rnn_dropout,
batch_first = True
)
else:
pass
self.output_layer = nn.Linear(self.n_cell, self.n_V)
def forward(self, x, hidden,lens):
rnnout, hidden = self.rnn(x, hidden)
output = self.drop(rnnout)
output = output.view(-1, output.size(2))
output = self.output_layer(output)
return output, hidden
def train_model(epoch, model, train_reader):
model.train()
args = model.args
batch_size = args.batch_size
total_loss = 0.0
criterion = nn.CrossEntropyLoss(size_average=False,ignore_index=-1)
hidden = model.init_hidden(batch_size)
i=0
running_acc=0
total_frame=0
while True:
feat,label,length = train_reader.load_next_nstreams()
if length is None or label.shape[0]<args.batch_size:
break
else:
x, y = Variable(torch.from_numpy(feat)).cuda(), Variable(torch.from_numpy(label).long()).cuda()
hidden = model.init_hidden(batch_size)
hidden = (Variable(hidden[0].data), Variable(hidden[1].data)) if args.lstm else Variable(hidden.data)
model.zero_grad()
output, hidden = model(x, hidden,length)
assert x.size(0) == batch_size
loss = criterion(output, y.view(-1))
_,predict = torch.max(output,1)
correct = (predict == y).sum()
loss.backward()
total_loss += loss.data[0]
running_acc += correct.data[0]
total_frame += sum(length)
i+=1
if i%10 == 0:
sys.stdout.write(“time:{}, Epoch={},trbatch={},loss={:.4f},tracc={:.4f}\n”.format(datetime.now(),epoch,i,total_loss/total_frame,
running_acc*1.0/total_frame))
sys.stdout.flush()