Very small grad for parameters in PyTorch

Hi, I’m implementing the ELMo model (the architecture from the paper, but with GRUs) in PyTorch for a sentiment analysis task (2 classes).
My problem: after training the model for 3 epochs (which takes almost 7 hours), the parameters are almost constant. They do get updated, but the grad value for every parameter is almost zero, so the parameters update very slowly.
After training the model on about 100 samples (just as a test, since each epoch takes a long time), I printed the model output on the trained samples (64 sentences). As you can see, all of the outputs are almost 0.61 or 0.62 (the model’s output before applying the sigmoid is close to zero):

[0.6190, 0.6177, 0.6218, 0.6209, 0.6216, 0.6177, 0.6218, 0.6248, 0.6187,
        0.6209, 0.6208, 0.6197, 0.6208, 0.6201, 0.6164, 0.6204, 0.6187, 0.6186,
        0.6172, 0.6227, 0.6180, 0.6176, 0.6177, 0.6189, 0.6167, 0.6162, 0.6204,
        0.6212, 0.6212, 0.6170, 0.6175, 0.6188, 0.6200, 0.6207, 0.6211, 0.6186,
        0.6171, 0.6190, 0.6171, 0.6215, 0.6204, 0.6166, 0.6169, 0.6189, 0.6192,
        0.6171, 0.6198, 0.6210, 0.6217, 0.6182, 0.6205, 0.6167, 0.6185, 0.6185,
        0.6247, 0.6201, 0.6183, 0.6172, 0.6248, 0.6156, 0.6187, 0.6221, 0.6184,
        0.6200]
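
In other words, the model seems stuck at an uninformative prediction. A quick sanity check (my own numbers, not from the training run): a constant output of ~0.62 corresponds to a pre-sigmoid logit of ~0.5, and predicting ~0.5 for every sample gives a BCE of ln 2 ≈ 0.693, which is exactly where the loss values below sit.

import math
import torch

print(torch.sigmoid(torch.tensor(0.5)))  # tensor(0.6225) -> matches the ~0.62 outputs
p = torch.full((64,), 0.5)
y = torch.randint(0, 2, (64,)).float()
print(torch.nn.functional.binary_cross_entropy(p, y))  # tensor(0.6931), i.e. ln 2
print(math.log(2))  # 0.6931...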

mean grad value for the first layer (character-based embedding) over 7 iterations (batch size 4):
-3.2057e-08
-1.0591e-07
8.0309e-10
-3.1149e-08
1.7176e-08
1.0479e-08
-5.9668e-08

loss values:
0.6922
0.6888
0.6932
0.6933
0.705
0.6812
0.7068

first layer parameters (before training):

Parameter containing:
tensor([[-0.8127,  0.0848, -1.8994,  ..., -0.4188,  0.0737,  1.7480],
        [-0.9858,  1.2334, -1.5336,  ..., -0.1520, -0.8097,  1.5319],
        [-0.3637,  0.2356, -0.6203,  ..., -0.2677,  0.3540, -0.8167],
        ...,
        [ 0.5995,  0.0444,  0.5533,  ..., -0.6380, -0.2782,  0.4377],
        [-1.1214,  0.1163,  0.6494,  ...,  0.9082,  0.0925, -2.0435],
        [ 1.1774,  2.0876,  1.2902,  ...,  0.1933,  0.6906, -0.9966]],
       device='cuda:0', requires_grad=True)

first layer parameters (after 1000 training iterations):

Parameter containing:
tensor([[ 0.4986, -0.1885, -2.1546,  ...,  1.6023,  1.0103, -0.0118],
        [-0.2110, -0.0524, -0.5779,  ..., -1.7709, -0.6997,  1.7685],
        [-0.8088, -0.0187,  0.4958,  ...,  0.2945, -0.8318,  0.5191],
        ...,
        [ 0.0324,  0.6847,  0.7107,  ..., -0.5620,  1.1643, -0.1883],
        [ 0.3290, -1.5829, -1.2789,  ..., -0.6205, -1.9693, -0.8639],
        [ 1.1525,  1.1839,  1.4262,  ...,  0.1396, -0.0622, -1.1427]],
       device='cuda:0', requires_grad=True)

Conv1d_Embed module (embedding + 1D convolution):

class Conv1d_Embed(nn.Module):
  def __init__(self, embed_dim, filters_list):
    super(Conv1d_Embed, self).__init__()
    self.filters_list = filters_list
    self.embed = nn.Embedding(num_embeddings=chars_count, embedding_dim=embed_dim, device=device)
    self.conv_list = nn.ModuleList(modules=None)
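    # note: unlike the other layers, this LayerNorm is not given device=device, so it is created on the CPU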
    self.conv_norm_layer = nn.LayerNorm([100, np.sum(np.array(self.filters_list)[:, 0])])
    for filter in filters_list:
      conv = nn.Conv1d(in_channels=embed_dim, out_channels=filter[0], kernel_size=filter[1], stride=1, padding=0, dilation=1, device=device)
      self.conv_list.append(conv)

  def forward(self, X):
    X = self.embed(X).permute(0, 1, 3, 2)
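    # note: torch.empty without device= allocates on the CPU, so the CUDA conv results below are copied off the GPU every batch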
    X_conv = torch.empty(size=(X.shape[0], X.shape[1], np.sum(np.array(self.filters_list)[:, 0])))
    for sentence_idx in range(X.shape[0]):
      idx_sum = 0
      for convolution in self.conv_list:
        torch.cuda.empty_cache()
        conv_result = convolution(X[sentence_idx])
        conv_result = torch.max(conv_result, dim=2).values
        seq_columns = convolution.out_channels
        X_conv[sentence_idx][:, idx_sum:idx_sum + seq_columns] = conv_result
        idx_sum += seq_columns
    X_conv = self.conv_norm_layer(X_conv)
    X_conv = torch.relu(X_conv)
    torch.cuda.empty_cache()
    return X_conv
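
A side note on the loop above (my observation, not part of the original code): because X_conv lives on the CPU and each sentence is convolved in a Python loop, every batch pays per-sentence GPU→CPU copies. A minimal vectorized sketch of the same forward pass, assuming conv_norm_layer is also created with device=device so everything stays on the GPU:

def forward(self, X):
  B, W = X.shape[0], X.shape[1]                 # sentences, words per sentence
  X = self.embed(X).permute(0, 1, 3, 2)         # (B, W, embed_dim, chars)
  X = X.reshape(B * W, X.shape[2], X.shape[3])  # fold words into the conv batch dim
  # run each convolution over all words at once, then max-pool over the char axis
  pooled = [torch.max(conv(X), dim=2).values for conv in self.conv_list]
  X_conv = torch.cat(pooled, dim=1).reshape(B, W, -1)
  return torch.relu(self.conv_norm_layer(X_conv))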

highway network module:

class Highway_Network(nn.Module):
  def __init__(self, H_act:str, in_dim:int):
    super(Highway_Network, self).__init__()
    if H_act == 'relu': self.H_act = nn.ReLU()
    elif H_act == 'tanh': self.H_act = nn.Tanh()
    else: self.H_act = nn.Sigmoid()
    self.in_dim = in_dim
    self.H = nn.Linear(in_features=in_dim, out_features=in_dim, bias=False, device=device)
    self.T = nn.Linear(in_features=in_dim, out_features=in_dim, bias=True, device=device)

  def forward(self, X):
    T = torch.sigmoid(self.T(X))
    H = self.H_act(self.H(X))
    y = (H * T) + (X * (1 - T))
    torch.cuda.empty_cache()
    return y
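
For reference, the forward pass implements the standard highway formula, with T acting as the transform gate:

$$y = T(x) \odot H(x) + \bigl(1 - T(x)\bigr) \odot x, \qquad T(x) = \sigma(W_T x + b_T)$$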

ELMo module:

class ELMo(nn.Module):
  def __init__(self, in_dim_for_highway, embed_dim, filters_list, proj_size, rnn_hidden_size):
    super(ELMo, self).__init__()
    self.conv1d_embed = Conv1d_Embed(embed_dim, filters_list)
    self.highway_layer1 = Highway_Network(H_act='tanh', in_dim=in_dim_for_highway)
    self.highway_layer2 = Highway_Network(H_act='tanh', in_dim=in_dim_for_highway)
    self.proj_after_highway = nn.Linear(in_features=in_dim_for_highway, out_features=proj_size, bias=True, device=device)
    self.norm_after_highway = nn.LayerNorm([100, proj_size], device=device)
    self.rnn_layer1_forward = nn.GRU(input_size=proj_size, hidden_size=rnn_hidden_size, num_layers=1, bias=True,
                                        batch_first=True, dropout=0, bidirectional=False, device=device)
    self.rnn_layer1_backward = nn.GRU(input_size=proj_size, hidden_size=rnn_hidden_size, num_layers=1, bias=True,
                                        batch_first=True, dropout=0, bidirectional=False, device=device)
    self.rnn_layer2_forward = nn.GRU(input_size=proj_size, hidden_size=rnn_hidden_size, num_layers=1, bias=True,
                                        batch_first=True, dropout=0, bidirectional=False, device=device)
    self.rnn_layer2_backward = nn.GRU(input_size=proj_size, hidden_size=rnn_hidden_size, num_layers=1, bias=True,
                                        batch_first=True, dropout=0, bidirectional=False, device=device)
    self.proj_after_rnn1_forward = nn.Linear(in_features=rnn_hidden_size, out_features=proj_size, bias=True, device=device)
    self.proj_after_rnn1_backward = nn.Linear(in_features=rnn_hidden_size, out_features=proj_size, bias=True, device=device)
    self.proj_after_rnn2_forward = nn.Linear(in_features=rnn_hidden_size, out_features=proj_size, bias=True, device=device)
    self.proj_after_rnn2_backward = nn.Linear(in_features=rnn_hidden_size, out_features=proj_size, bias=True, device=device)
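    # 102400 = 100 (words per sentence) * 2 * proj_size, i.e. the concatenated forward/backward features flattened per sentence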
    self.output_layer = nn.Linear(in_features=102400, out_features=1, bias=True, device=device)

  def forward(self, X):
    output = self.conv1d_embed(X).to(device)
    output = self.highway_layer1(output)
    output = self.highway_layer2(output)
    output = self.proj_after_highway(output)
    output = self.norm_after_highway(output)
    output = torch.relu(output)

    forward_output = self.rnn_layer1_forward(output)[0] # forward
    forward_output = torch.relu(forward_output)
    forward_output = self.proj_after_rnn1_forward(forward_output)
    forward_output = torch.relu(forward_output)

    backward_output = self.rnn_layer1_backward(torch.flip(output, dims=[1]))[0] # backward
    backward_output = torch.relu(backward_output)
    backward_output = self.proj_after_rnn1_backward(backward_output)
    backward_output = torch.relu(backward_output)

    forward_output = self.rnn_layer2_forward(forward_output)[0]
    forward_output = torch.relu(forward_output)
    forward_output = self.proj_after_rnn2_forward(forward_output)
    forward_output = torch.relu(forward_output)

    backward_output = self.rnn_layer2_backward(backward_output)[0]
    backward_output = torch.relu(backward_output)
    backward_output = self.proj_after_rnn2_backward(backward_output)
    backward_output = torch.relu(backward_output)
    backward_output = torch.flip(backward_output, dims=[1])

    output = torch.concat((forward_output, backward_output), dim=2)
    output = output.reshape((output.shape[0], output.shape[1] * output.shape[2]))
    output = self.output_layer(output)
    output = torch.sigmoid(output)
    return output

some other details:

embed_dim = 50
model_location = 'drive/MyDrive/elmo_dataset_words_lower_100/elmo_model.mdl'
optimizer_location = 'drive/MyDrive/elmo_dataset_words_lower_100/elmo_optimizer.optm'
filters_list = [[32, 1], [32, 2], [64, 3], [128, 4], [256, 5], [512, 6], [1024, 7]]
in_dim_for_highway = np.sum(np.array(filters_list)[:, 0])
proj_size = 512
rnn_hidden_size = 4096

Training loop (forward + backward pass):

model = ELMo(in_dim_for_highway, embed_dim, filters_list, proj_size, rnn_hidden_size)
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

# model.load_state_dict(torch.load(model_location))
# optimizer.load_state_dict(torch.load(optimizer_location))
print(summary(model))

batch_size = 4
epochs = 5 # Started by 5
bce = nn.BCELoss()
new_slices = slices = pd.read_csv('drive/MyDrive/elmo_dataset_words_lower_100/slice_list.csv').drop(columns=['Unnamed: 0']) # slice 10 is for test

for slice_idx in range(len(slices)):
  slice_path = slices.iloc[slice_idx, :].values[0]
  print(f'Training ELMo on {slice_path}...')
  dataset = np.load(slice_path)
  labels = torch.Tensor(dataset['labels'].astype(np.float32)).to('cpu')
  dataset = torch.Tensor(dataset['data']).type(torch.int32).to('cpu')
  for label_idx in range(len(labels)):
    if labels[label_idx] == -1: labels[label_idx] = 0
    # elif labels[label_idx] == 0: labels[label_idx] = 1
    elif labels[label_idx] == 1: labels[label_idx] = 1
  dataset_size = dataset.shape[0]
  dataset_loss = list()
  idx = torch.randperm(dataset.shape[0])
  dataset = dataset[idx] # Randomization
  labels = labels[idx] # Randomization
  for batch in range(batch_size, dataset.shape[0] + batch_size, batch_size):
    optimizer.zero_grad()
    X = dataset[batch - batch_size:batch].to(device)
    y = labels[batch - batch_size:batch].to(device)
    output = model(X).squeeze()
    loss = bce(output, y)
    loss.backward()
    optimizer.step()
    print(torch.mean(list(model.parameters())[0].grad))
    loss_value = loss.item()
    dataset_loss.append(loss_value)
    print(f'Batch: {batch} - Loss: {loss_value} - Dataset size: {dataset_size}')
  print('---------------------')
  torch.save(model.state_dict(), model_location)
  torch.save(optimizer.state_dict(), optimizer_location)
  print(f'Dataset slice: {slice_path} - Loss: {np.mean(dataset_loss)}')
  print(f'Trained model saved in {model_location}')
  print(f'Optimizer saved in {optimizer_location}')
  print('---------------------')
  new_slices = new_slices.drop(index=slice_idx)
  new_slices.to_csv('drive/MyDrive/elmo_dataset_words_lower_100/slice_list.csv')
  del X, y, dataset, labels, output
  collect()
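
One numerical-stability note (a sketch, not something this run used): applying sigmoid in forward and then BCELoss is mathematically equivalent to nn.BCEWithLogitsLoss on the raw logits, but the fused version computes the log-sigmoid stably and avoids log(0) when the sigmoid saturates:

import torch
import torch.nn as nn

logits = torch.randn(4)
y = torch.randint(0, 2, (4,)).float()
# equivalent losses, but the fused version is numerically safer
print(nn.BCELoss()(torch.sigmoid(logits), y))
print(nn.BCEWithLogitsLoss()(logits, y))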

I’ve examined every hyper-parameter you can think of (batch size, learning rate, activation functions, projection size, etc.) and checked the labels.
What is the problem? I think there’s a mistake in how I’m using PyTorch modules like autograd…

I printed the grad values for the weights of every layer (with batch size 16), and my guess is that vanishing gradients are occurring. That is, the further we get from the last layer (the output layer), the closer the grad values are to zero:

print(f'embed: {torch.sum(list(model.conv1d_embed.parameters())[1].grad)}')
print(f'conv1: {torch.sum(list(model.conv1d_embed.parameters())[3].grad)}')
print(f'conv2: {torch.sum(list(model.conv1d_embed.parameters())[5].grad)}')
print(f'conv3: {torch.sum(list(model.conv1d_embed.parameters())[7].grad)}')
print(f'conv4: {torch.sum(list(model.conv1d_embed.parameters())[9].grad)}')
print(f'conv5: {torch.sum(list(model.conv1d_embed.parameters())[11].grad)}')
print(f'highway1: {torch.sum(list(model.highway_layer1.parameters())[0].grad)}')
print(f'highway2: {torch.sum(list(model.highway_layer2.parameters())[0].grad)}')
print(f'norm_after_highway: {torch.sum(list(model.norm_after_highway.parameters())[0].grad)}')
print(f'proj_after_highway: {torch.sum(list(model.proj_after_highway.parameters())[0].grad)}')
print(f'rnn_layer1_backward: {torch.sum(list(model.rnn_layer1_backward.parameters())[0].grad)}')
print(f'proj_after_rnn1_backward: {torch.sum(list(model.proj_after_rnn1_backward.parameters())[0].grad)}')
print(f'rnn_layer2_backward: {torch.sum(list(model.rnn_layer2_backward.parameters())[0].grad)}')
print(f'proj_after_rnn2_backward: {torch.sum(list(model.proj_after_rnn2_backward.parameters())[0].grad)}')
print(f'output: {torch.sum(list(model.output_layer.parameters())[0].grad)}')
print(f'output_bias: {torch.sum(list(model.output_layer.parameters())[1].grad)}')
Training ELMo on drive/MyDrive/elmo_dataset_words_lower_100/slice4.npz...
embed: -4.6593424485763535e-05
conv1: 3.8413108995882794e-05
conv2: 0.0007232016650959849
conv3: 0.00017461027891840786
conv4: 0.0025903673376888037
conv5: -0.0006756336661055684
highway1: -0.016100123524665833
highway2: -0.1111525222659111
norm_after_highway: -6.979409954510629e-05
proj_after_highway: 0.015135683119297028
rnn_layer1_backward: -0.13128359615802765
proj_after_rnn1_backward: -0.2215435951948166
rnn_layer2_backward: 0.06897228956222534
proj_after_rnn2_backward: 1.4058352708816528
output: -62.17420959472656
output_bias: -0.12514057755470276
-------------
Batch: 16 - Loss: 0.6932045221328735 - Dataset size: 20000
embed: 4.096475458936766e-05
conv1: -0.00011880289821419865
conv2: -5.4191244998946786e-05
conv3: -0.0008076142403297126
conv4: 0.0006381101557053626
conv5: -0.0010407972149550915
highway1: -0.03469737619161606
highway2: -0.04613765701651573
norm_after_highway: -0.0007416537264361978
proj_after_highway: -0.029202647507190704
rnn_layer1_backward: -0.4190102219581604
proj_after_rnn1_backward: -0.9112290143966675
rnn_layer2_backward: -0.2532331347465515
proj_after_rnn2_backward: 0.22813838720321655
output: -60.597957611083984
output_bias: -0.12170735001564026
-------------
Batch: 32 - Loss: 0.6914718151092529 - Dataset size: 20000
embed: -6.607886462006718e-05
conv1: -0.00018568531959317625
conv2: -0.00043446820927783847
conv3: -0.0013823093613609672
conv4: 0.0003382552822586149
conv5: -0.0029913655016571283
highway1: -0.05714196339249611
highway2: -0.35561251640319824
norm_after_highway: -0.002709394320845604
proj_after_highway: 0.01186932623386383
rnn_layer1_backward: -1.5879170894622803
proj_after_rnn1_backward: -3.678624153137207
rnn_layer2_backward: -1.0227971076965332
proj_after_rnn2_backward: -1.8797272443771362
output: -122.38917541503906
output_bias: -0.24342972040176392
-------------
Batch: 48 - Loss: 0.6866424679756165 - Dataset size: 20000
embed: -3.808040128205903e-05
conv1: -2.0495388525887392e-05
conv2: 0.0001296251139137894
conv3: -0.0004166905127931386
conv4: 7.206742884591222e-05
conv5: -0.0015585473738610744
highway1: -0.016968537122011185
highway2: -0.06996846199035645
norm_after_highway: -0.0008239669259637594
proj_after_highway: 0.07070107012987137
rnn_layer1_backward: -0.4688308835029602
proj_after_rnn1_backward: -1.1465927362442017
rnn_layer2_backward: -0.3769247829914093
proj_after_rnn2_backward: -0.7544925212860107
output: -26.983747482299805
output_bias: -0.05277787894010544
-------------
Batch: 64 - Loss: 0.6909307241439819 - Dataset size: 20000
embed: 0.00012138157762819901
conv1: -0.0001396716688759625
conv2: 0.00018317438662052155
conv3: -0.0006879045977257192
conv4: 0.0005162524175830185
conv5: -0.00472242571413517
highway1: -0.006774289533495903
highway2: -0.12450309097766876
norm_after_highway: -0.003364659147337079
proj_after_highway: -0.003951713442802429
rnn_layer1_backward: -1.903158187866211
proj_after_rnn1_backward: -4.582720756530762
rnn_layer2_backward: -1.7180144786834717
proj_after_rnn2_backward: -3.437438488006592
output: -91.4657211303711
output_bias: -0.17489370703697205
-------------
Batch: 80 - Loss: 0.683960497379303 - Dataset size: 20000
embed: 2.139205207640771e-06
conv1: 6.419386772904545e-05
conv2: 1.276363036595285e-06
conv3: 5.8115772844757885e-05
conv4: 0.00027716817567124963
conv5: -4.813261330127716e-05
highway1: -0.01209902111440897
highway2: 0.02906982973217964
norm_after_highway: 0.000386736704967916
proj_after_highway: -0.08745254576206207
rnn_layer1_backward: 0.19513759016990662
proj_after_rnn1_backward: 0.465777724981308
rnn_layer2_backward: 0.23875851929187775
proj_after_rnn2_backward: 0.35944581031799316
output: 8.396147727966309
output_bias: 0.015698343515396118
-------------
Batch: 96 - Loss: 0.6936017870903015 - Dataset size: 20000
embed: -1.0425923392176628e-05
conv1: 6.876498082419857e-05
conv2: 7.341942546190694e-05
conv3: 0.00045036099618300796
conv4: 0.0002511006605345756
conv5: 0.00017986627062782645
highway1: -0.023883046582341194
highway2: 0.036179639399051666
norm_after_highway: 0.000519546156283468
proj_after_highway: -0.12831953167915344
rnn_layer1_backward: 0.2911137342453003
proj_after_rnn1_backward: 0.6544842720031738
rnn_layer2_backward: 0.31721436977386475
proj_after_rnn2_backward: 0.7356137037277222
output: 10.009611129760742
output_bias: 0.018242139369249344
-------------
Batch: 112 - Loss: 0.6937789916992188 - Dataset size: 20000
embed: 0.00016911677084863186
conv1: -3.84192171622999e-05
conv2: 0.0005247838562354445
conv3: -0.0006903375033289194
conv4: -0.00027979101287201047
conv5: -0.0020676155108958483
highway1: -0.08901641517877579
highway2: -0.3113752603530884
norm_after_highway: -0.004887045361101627
proj_after_highway: 0.030139975249767303
rnn_layer1_backward: -2.424699306488037
proj_after_rnn1_backward: -5.636422157287598
rnn_layer2_backward: -2.6509008407592773
proj_after_rnn2_backward: -6.687723636627197
output: -94.3984603881836
output_bias: -0.16710714995861053
-------------
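
A tidier way to log the same thing (a sketch, not what I actually ran): iterating over named_parameters and taking the mean absolute value avoids both the positional indexing above and the sign cancellation you get with torch.sum:

for name, param in model.named_parameters():
  if param.grad is not None:
    print(f'{name}: {param.grad.abs().mean().item():.3e}')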

Now I’m using residual connections between some layers, and the previous problem is almost solved! But there’s a new problem. As you can see in the results below, the grad value for the last layer (the output layer) is very, very high:

Training ELMo on drive/MyDrive/elmo_dataset_words_lower_100/slice2.npz...
embed: -0.31124985218048096
conv1: -0.23671860992908478
conv2: 0.9753038883209229
conv3: 1.0003304481506348
conv4: 0.2072356939315796
conv5: -0.9559061527252197
proj_before_highway: -136.29574584960938
highway: -0.25242868065834045
rnn_layer1_backward: 4.168163776397705
proj_after_rnn1_backward: -0.29323554039001465
rnn_layer2_backward: 0.6222412586212158
proj_after_rnn2_backward: -2.134066343307495
output: -6835.08984375
output_bias: -0.38621053099632263
-------------
Batch: 16 - Loss: 0.7886543273925781 - Dataset size: 20000
embed: 0.21649760007858276
conv1: -0.009094476699829102
conv2: -0.4256786108016968
conv3: 0.9828163385391235
conv4: -2.033944606781006
conv5: -0.47921907901763916
proj_before_highway: 1298.85107421875
highway: 2.451516628265381
rnn_layer1_backward: 2.9442484378814697
proj_after_rnn1_backward: 2.2511916160583496
rnn_layer2_backward: 7.089057922363281
proj_after_rnn2_backward: 5.113781929016113
output: 5165.3310546875
output_bias: 0.24871262907981873
-------------
Batch: 32 - Loss: 1.743406057357788 - Dataset size: 20000
embed: 0.40841901302337646
conv1: 0.0024018846452236176
conv2: -1.0121322870254517
conv3: 1.370057463645935
conv4: -4.331480979919434
conv5: -1.6278128623962402
proj_before_highway: 2444.46630859375
highway: -0.4508698582649231
rnn_layer1_backward: 0.47323620319366455
proj_after_rnn1_backward: 0.8400318026542664
rnn_layer2_backward: 10.401853561401367
proj_after_rnn2_backward: 3.7883834838867188
output: 9563.591796875
output_bias: 0.5599931478500366
-------------
Batch: 48 - Loss: 3.407438278198242 - Dataset size: 20000
embed: 0.3553653955459595
conv1: 0.09798091650009155
conv2: -1.2439923286437988
conv3: 0.7889980673789978
conv4: -2.6324338912963867
conv5: 0.8467768430709839
proj_before_highway: 968.948974609375
highway: -2.5037989616394043
rnn_layer1_backward: -3.0111083984375
proj_after_rnn1_backward: -0.030773989856243134
rnn_layer2_backward: 2.8658604621887207
proj_after_rnn2_backward: 0.13935384154319763
output: 6282.17822265625
output_bias: 0.5004515647888184
-------------
Batch: 64 - Loss: 1.6010830402374268 - Dataset size: 20000
embed: 0.014941547065973282
conv1: -0.016871655359864235
conv2: 0.09186001121997833
conv3: 0.0824129730463028
conv4: -0.07007266581058502
conv5: -0.4893345534801483
proj_before_highway: -2.7973151206970215
highway: -0.07115613669157028
rnn_layer1_backward: 0.15298619866371155
proj_after_rnn1_backward: 0.035967789590358734
rnn_layer2_backward: 0.07593191415071487
proj_after_rnn2_backward: -0.04055314138531685
output: -174.4060821533203
output_bias: -0.01941879838705063
-------------
Batch: 80 - Loss: 0.6552839875221252 - Dataset size: 20000
embed: -0.1801912486553192
conv1: -0.49067553877830505
conv2: 2.368882417678833
conv3: 1.0036269426345825
conv4: -0.852114200592041
conv5: -2.5832297801971436
proj_before_highway: 1474.71923828125
highway: -11.867355346679688
rnn_layer1_backward: 2.614063262939453
proj_after_rnn1_backward: 0.8170901536941528
rnn_layer2_backward: 1.628517746925354
proj_after_rnn2_backward: 1.8041082620620728
output: -4941.58740234375
output_bias: -0.6040517091751099
-------------
Batch: 96 - Loss: 1.7138996124267578 - Dataset size: 20000
embed: -0.03272230923175812
conv1: -0.27587825059890747
conv2: 1.1577472686767578
conv3: 0.5143566131591797
conv4: -1.0061132907867432
conv5: -1.739422082901001
proj_before_highway: 731.8646240234375
highway: -6.7713117599487305
rnn_layer1_backward: -0.15281789004802704
proj_after_rnn1_backward: 0.04255335405468941
rnn_layer2_backward: 0.5587131977081299
proj_after_rnn2_backward: 0.4053055942058563
output: -1328.992919921875
output_bias: -0.30003494024276733
-------------
Batch: 112 - Loss: 0.9406654834747314 - Dataset size: 20000
embed: -0.11346188932657242
conv1: -0.2776741683483124
conv2: 0.9385225772857666
conv3: 0.9329437613487244
conv4: -0.36790353059768677
conv5: -2.7052650451660156
proj_before_highway: 853.1502685546875
highway: -6.701651096343994
rnn_layer1_backward: -0.39196327328681946
proj_after_rnn1_backward: -0.17777833342552185
rnn_layer2_backward: 0.21363681554794312
proj_after_rnn2_backward: 0.349449098110199
output: -755.6411743164062
output_bias: -0.34320610761642456
-------------
Batch: 128 - Loss: 0.9142484664916992 - Dataset size: 20000
embed: -0.01170092448592186
conv1: 0.05439986288547516
conv2: 0.07825899124145508
conv3: 0.15789179503917694
conv4: -0.23841612040996552
conv5: -0.42686524987220764
proj_before_highway: 14.587163925170898
highway: -0.06879334151744843
rnn_layer1_backward: -0.08501295745372772
proj_after_rnn1_backward: 0.0024774931371212006
rnn_layer2_backward: -0.09783315658569336
proj_after_rnn2_backward: 0.058680903166532516
output: 104.94813537597656
output_bias: -0.003501247614622116
-------------
Batch: 144 - Loss: 0.643433153629303 - Dataset size: 20000
embed: 0.010831267572939396
conv1: 0.07044383883476257
conv2: -0.49638843536376953
conv3: -0.9110832214355469
conv4: 0.6502820253372192
conv5: 5.40472412109375
proj_before_highway: -393.1110534667969
highway: -3.860703229904175
rnn_layer1_backward: -0.12032938003540039
proj_after_rnn1_backward: -0.21743673086166382
rnn_layer2_backward: 0.8317618370056152
proj_after_rnn2_backward: 0.9559182524681091
output: -688.6527709960938
output_bias: 0.29738783836364746
-------------
Batch: 160 - Loss: 1.0594557523727417 - Dataset size: 20000
embed: -0.047927286475896835
conv1: 0.027734097093343735
conv2: -0.16607719659805298
conv3: -1.6493910551071167
conv4: 0.009051352739334106
conv5: 5.373091697692871
proj_before_highway: -562.8831787109375
highway: -3.9806060791015625
rnn_layer1_backward: 0.4737078845500946
proj_after_rnn1_backward: -0.06469704955816269
rnn_layer2_backward: 0.351388543844223
proj_after_rnn2_backward: -0.09641136229038239
output: -808.8811645507812
output_bias: 0.35645124316215515
-------------
Batch: 176 - Loss: 1.0919063091278076 - Dataset size: 20000
embed: -0.06440611928701401
conv1: 0.01709364727139473
conv2: -0.07498355209827423
conv3: -0.35806286334991455
conv4: 0.08119155466556549
conv5: 1.282605528831482
proj_before_highway: -30.3171443939209
highway: -1.9120361804962158
rnn_layer1_backward: 0.2095092535018921
proj_after_rnn1_backward: 0.06088166683912277
rnn_layer2_backward: 0.07566089928150177
proj_after_rnn2_backward: -0.012311633676290512
output: -355.6859130859375
output_bias: 0.05726774409413338
-------------
Batch: 192 - Loss: 0.6065136790275574 - Dataset size: 20000
embed: -0.009362181648612022
conv1: 0.028817325830459595
conv2: -0.10502660274505615
conv3: -0.716947078704834
conv4: 0.12533381581306458
conv5: 2.081275463104248
proj_before_highway: -144.4345245361328
highway: -1.598570466041565
rnn_layer1_backward: 0.9062970876693726
proj_after_rnn1_backward: 0.26810312271118164
rnn_layer2_backward: 0.23268043994903564
proj_after_rnn2_backward: -0.3262060284614563
output: -586.6190795898438
output_bias: 0.12614455819129944
-------------
Batch: 208 - Loss: 0.7223302125930786 - Dataset size: 20000
embed: 0.0059539545327425
conv1: -0.0004953863099217415
conv2: -0.008280763402581215
conv3: 0.2739175856113434
conv4: 0.26986101269721985
conv5: 0.1452290117740631
proj_before_highway: 95.26295471191406
highway: -0.4947985112667084
rnn_layer1_backward: -0.46712470054626465
proj_after_rnn1_backward: -0.14849968254566193
rnn_layer2_backward: -0.05727086216211319
proj_after_rnn2_backward: 0.2119525671005249
output: 178.8102264404297
output_bias: -0.05756787210702896

Why is this happening? I used gradient clipping, but then the grad values for the other layers decrease too.
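
For what it’s worth, torch.nn.utils.clip_grad_norm_ computes a single global norm over all parameters and rescales every grad by the same factor, so when the output layer’s grad explodes, every other layer’s (already small) grad gets scaled down with it. Typical placement in the loop above:

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()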

Hi.
I finally found a solution.
The ELMo architecture described in the original paper is very big! In other words, it needs a huge dataset to train. My dataset had only about 160,000 samples (while the original ELMo has almost 300 million parameters), so I changed the architecture and built a new model with far fewer parameters, and the problem was solved.
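
To check the size claim for the configuration earlier in this thread, a quick count (my arithmetic, as a sketch):

# total trainable parameters
print(sum(p.numel() for p in model.parameters()))

# back-of-envelope: one GRU with input 512 and hidden 4096 has
# 3*H*(I + H) + 6*H = 3*4096*(512 + 4096) + 6*4096 ≈ 56.6M parameters,
# so the four GRUs alone contribute ≈ 226.6M.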