I am trying to train my GRU network, but it is overfitting. After around 20 epochs the train and test accuracies diverge considerably: by epoch 91 the train accuracy reaches ~90% while the test accuracy is only 67%.
I tried to reduce the overfitting by adding a dropout layer, but it doesn't help. I have 40,091 training samples and 16,487 test samples.
class EncoderRNN(nn.Module):
    """GRU encoder that returns, per sequence, the GRU output at the last
    valid (non-padded) timestep.

    Input is expected batch-first: (batch, max_len, input_size), with
    per-sample true lengths given separately.
    """

    def __init__(self, input_size, hidden_size, num_layers, bi_flag):
        """
        Args:
            input_size: number of features per timestep.
            hidden_size: GRU hidden dimension.
            num_layers: number of stacked GRU layers.
            bi_flag: if True, the GRU is bidirectional (output feature
                dimension doubles).
        """
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
                          bidirectional=bi_flag, batch_first=True)
        self.num_layers = num_layers

    def forward(self, input_tensor, seq_len):
        """Encode a padded batch and pick each sequence's last valid output.

        Args:
            input_tensor: float tensor of shape (batch, max_len, input_size).
            seq_len: iterable of ints, the true length of each sequence.

        Returns:
            Tensor of shape (1, batch, hidden_size * num_directions).
        """
        # FIX: the original fed the GRU one timestep at a time in a Python
        # loop and rebuilt the output with repeated torch.cat — O(T^2) in
        # copies and very slow. For a unidirectional GRU a single call over
        # the whole sequence is mathematically identical because the GRU is
        # causal; for a bidirectional GRU this is the intended semantics
        # (per-step calls let the backward direction see only one frame).
        encoder_output, _ = self.gru(input_tensor)

        # FIX: gather the last valid timestep of every sequence with
        # vectorized advanced indexing instead of a Python loop writing into
        # uninitialized torch.empty. Device is taken from the input rather
        # than a module-level global.
        lengths = torch.as_tensor(list(seq_len), dtype=torch.long,
                                  device=input_tensor.device)
        batch_idx = torch.arange(lengths.size(0), device=input_tensor.device)
        # (batch, feat) -> (1, batch, feat) to keep the original return shape.
        hidden = encoder_output[batch_idx, lengths - 1, :].unsqueeze(0)
        return hidden
class seq2seq(nn.Module):
    """Sequence classifier: a GRU encoder followed by a small MLP head.

    Despite the name, the visible code performs classification, not
    decoding: the encoder's last valid hidden output is passed through
    dropout, a Linear+ReLU projection, and a final Linear producing
    `num_class` logits per sequence.
    """

    def __init__(self, en_input_size, en_hidden_size, output_size, batch_size,
                 en_num_layers=3, de_num_layers=1,
                 fix_state=False, fix_weight=False, teacher_force=False,
                 bi_flag=False, negative_r=None, reverse_flag=False,
                 num_class=60):
        """
        Args:
            en_input_size: features per timestep fed to the encoder.
            en_hidden_size: encoder GRU hidden dimension.
            output_size, batch_size, de_num_layers, fix_state, fix_weight,
            negative_r, reverse_flag: kept for interface compatibility;
                unused in the code visible here.
            en_num_layers: number of stacked encoder GRU layers.
            teacher_force: stored but unused in the visible code.
            bi_flag: bidirectional encoder (doubles the feature dimension).
            num_class: number of output classes (FIX: previously a
                hard-coded local `num_class = 60`; now a backward-compatible
                keyword parameter defaulting to 60).
        """
        super(seq2seq, self).__init__()
        # One feature block per GRU direction (the original computed this as
        # `decoder_num` but never used it).
        num_directions = 2 if bi_flag else 1
        self.batch_size = batch_size
        self.en_num_layers = en_num_layers
        # NOTE(review): the original called `.to(device)` here via a
        # module-level global not defined in this file; removed so device
        # placement is the caller's responsibility (model.to(device)).
        self.encoder = EncoderRNN(en_input_size, en_hidden_size,
                                  en_num_layers, bi_flag)
        self.en_input_size = en_input_size
        self.teacher_force = teacher_force
        self.dropout = nn.Dropout(0.5)
        # FIX: input width was hard-coded to 1024, which only worked when
        # en_hidden_size * num_directions happened to equal 1024 (any other
        # configuration crashed with a shape error). Derive it instead.
        self.linear = nn.Sequential(
            nn.Linear(en_hidden_size * num_directions, 256),
            nn.ReLU(inplace=True),
        )
        self.fcn = nn.Linear(256, num_class)

    def forward(self, input_tensor, seq_len, index=None, cluster_result=None):
        """Classify each sequence in the batch.

        Args:
            input_tensor: (batch, max_len, en_input_size) padded batch.
            seq_len: per-sample true lengths.
            index, cluster_result: kept for interface compatibility; unused
                in the visible code.

        Returns:
            Logits of shape (batch, num_class).
        """
        self.batch_size = len(seq_len)
        encoder_hidden = self.encoder(input_tensor, seq_len)  # (1, batch, feat)
        encoder_hidden = self.dropout(encoder_hidden)
        middle = self.linear(encoder_hidden)
        x = self.fcn(middle)          # (1, batch, num_class)
        x = x.view(x.size(1), -1)     # squeeze the leading dim -> (batch, num_class)
        return x