I am building an attention model and I cannot get the loss function to work. Could you help me?

Here is my model:

class Build_Model(nn.Module):
    def __init__(self, args):
        super(Build_Model, self).__init__()
        self.hidden_size = args.dec_size
        self.embedding = nn.Embedding(args.n_vocab, args.d_model)
        self.enc_lstm = nn.LSTM(input_size=args.d_model, hidden_size=args.d_model, batch_first=True)
        self.dec_lstm = nn.LSTM(input_size=args.d_model, hidden_size=args.d_model, batch_first=True)
        self.soft_prob = nn.Softmax(dim=-1)
        self.softmax_linear = nn.Linear(args.d_model * 2, args.n_vocab)
        self.softmax_linear_function = nn.Softmax(dim=-1)

    def forward(self, enc_inputs, dec_inputs):
        enc_hidden = self.embedding(enc_inputs)
        dec_hidden = self.embedding(dec_inputs)
        # encode, then initialize the decoder with the encoder's final states
        enc_hidden, (enc_h_state, enc_c_state) = self.enc_lstm(enc_hidden)
        dec_hidden, (dec_h_state, dec_c_state) = self.dec_lstm(dec_hidden, (enc_h_state, enc_c_state))
        # dot-product attention over the encoder outputs
        attn_score = torch.matmul(dec_hidden, torch.transpose(enc_hidden, 2, 1))
        attn_prob = self.soft_prob(attn_score)
        attn_out = torch.matmul(attn_prob, enc_hidden)
        cat_hidden = torch.cat((attn_out, dec_hidden), -1)
        y_pred = self.softmax_linear_function(self.softmax_linear(cat_hidden))
        # argmax collapses the vocab dimension, so y_pred now holds token ids
        y_pred = torch.argmax(y_pred, dim=-1)
        print('y_pred = ', y_pred.shape)
        y_pred = y_pred.view(-1, 150)
        print('2y_pred = ', y_pred.shape)
        return y_pred

Here is the loss function:

def lm_loss(y_true, y_pred):
    print(y_pred.shape)
    y_pred_argmax = y_pred
    #y_pred_argmax = y_pred_argmax.view(-1,150)
    print(y_true.shape, y_pred_argmax.shape)
    criterion = nn.CrossEntropyLoss(reduction="none")
    loss = criterion(y_true.float(), y_pred_argmax.float()[0])
    #mask = tf.not_equal(y_true, 0)
    mask = torch.not_equal(y_pred_argmax,0)
    #mask = tf.cast(mask, tf.float32)
    mask = mask.type(torch.FloatTensor).to(device)
    # zero out the loss on padding positions, then average over the rest
    loss *= mask
    #loss = tf.reduce_sum(loss) / tf.maximum(tf.reduce_sum(mask), 1)
    loss = torch.sum(loss) / torch.clamp(torch.sum(mask), min=1)
    return loss
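
For comparison, here is the calling convention I understand nn.CrossEntropyLoss to have: the first argument is raw logits (not probabilities or argmax indices) and the second is Long class indices, with logits of shape [N, C] against targets of shape [N]. A minimal sketch with placeholder sizes:

import torch
import torch.nn as nn

batch, seq_len, n_vocab = 32, 150, 100
logits = torch.randn(batch, seq_len, n_vocab)          # raw scores from the final Linear layer
targets = torch.randint(0, n_vocab, (batch, seq_len))  # gold token ids (LongTensor)

criterion = nn.CrossEntropyLoss(reduction="none")
# flatten so input is [N, C] and target is [N], with N = batch * seq_len
loss = criterion(logits.view(-1, n_vocab), targets.view(-1)).view(batch, seq_len)
print(loss.shape)  # torch.Size([32, 150]), one loss value per token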

Finally, here is the training step where the loss is computed:

optimizer.zero_grad()
print(train_enc_inputs.shape, train_dec_inputs.shape, train_dec_labels.shape)
y_pred = model(train_enc_inputs, train_dec_inputs)
#y_pred = torch.argmax(y_pred,dim =-1)
print(y_pred.shape )
loss = lm_loss(train_dec_labels, y_pred)

OUTPUT

torch.Size([32, 120]) torch.Size([32, 150]) torch.Size([32, 150])
y_pred = torch.Size([32, 150])
2y_pred = torch.Size([32, 150])
torch.Size([32, 150])
torch.Size([32, 150])
torch.Size([32, 150]) torch.Size([32, 150])

And here is the error. How can I fix it?

ValueError                                Traceback (most recent call last)
<ipython-input> in <module>()
      9     #y_pred = torch.argmax(y_pred,dim =-1)
     10     print(y_pred.shape )
---> 11     loss = lm_loss(train_dec_labels, y_pred)
     12     n_step += 1
     13     if n_step % 10 == 0:

<ipython-input> in lm_loss(y_true, y_pred)
     15     print(y_true.shape, y_pred_argmax.shape)
     16     criterion = nn.CrossEntropyLoss(reduction="none")
---> 17     loss = criterion(y_true.float(), y_pred_argmax.float()[0])
     18     #mask = tf.not_equal(y_true, 0)
     19     mask = torch.not_equal(y_pred_argmax,0)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
   1119     def forward(self, input: Tensor, target: Tensor) -> Tensor:
   1120         return F.cross_entropy(input, target, weight=self.weight,
-> 1121                                ignore_index=self.ignore_index, reduction=self.reduction)
   1122
   1123

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825
   2826

ValueError: Expected input batch_size (32) to match target batch_size (150).
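
If I read the traceback correctly, criterion(input, target) takes the prediction first and the label second, so my swapped call makes the loss treat y_true (shape [32, 150]) as 32 rows of logits over 150 classes, and y_pred[0] (shape [150]) as the batch of targets, which is where the 32 vs 150 mismatch comes from. A minimal repro of that shape check (with integer targets, which hit the same size validation):

import torch
import torch.nn as nn

fake_logits = torch.randn(32, 150)      # what y_true.float() looks like to the loss
fake_targets = torch.zeros(150).long()  # what y_pred[0] looks like to the loss
nn.CrossEntropyLoss(reduction="none")(fake_logits, fake_targets)
# ValueError: Expected input batch_size (32) to match target batch_size (150).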